| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
| 6 | 6 |
| 7 #include <algorithm> | 7 #include <algorithm> |
| 8 #include <iterator> | 8 #include <iterator> |
| 9 #include <set> | 9 #include <set> |
| 10 #include <string> | 10 #include <string> |
| (...skipping 17 matching lines...) Expand all Loading... |
| 28 #include "net/base/io_buffer.h" | 28 #include "net/base/io_buffer.h" |
| 29 #include "net/http/http_request_headers.h" | 29 #include "net/http/http_request_headers.h" |
| 30 #include "net/http/http_request_info.h" | 30 #include "net/http/http_request_info.h" |
| 31 #include "net/http/http_response_body_drainer.h" | 31 #include "net/http/http_response_body_drainer.h" |
| 32 #include "net/http/http_response_headers.h" | 32 #include "net/http/http_response_headers.h" |
| 33 #include "net/http/http_status_code.h" | 33 #include "net/http/http_status_code.h" |
| 34 #include "net/http/http_stream_parser.h" | 34 #include "net/http/http_stream_parser.h" |
| 35 #include "net/socket/client_socket_handle.h" | 35 #include "net/socket/client_socket_handle.h" |
| 36 #include "net/socket/websocket_transport_client_socket_pool.h" | 36 #include "net/socket/websocket_transport_client_socket_pool.h" |
| 37 #include "net/websockets/websocket_basic_stream.h" | 37 #include "net/websockets/websocket_basic_stream.h" |
| 38 #include "net/websockets/websocket_deflate_parameters.h" |
| 38 #include "net/websockets/websocket_deflate_predictor.h" | 39 #include "net/websockets/websocket_deflate_predictor.h" |
| 39 #include "net/websockets/websocket_deflate_predictor_impl.h" | 40 #include "net/websockets/websocket_deflate_predictor_impl.h" |
| 40 #include "net/websockets/websocket_deflate_stream.h" | 41 #include "net/websockets/websocket_deflate_stream.h" |
| 41 #include "net/websockets/websocket_deflater.h" | 42 #include "net/websockets/websocket_deflater.h" |
| 42 #include "net/websockets/websocket_extension_parser.h" | 43 #include "net/websockets/websocket_extension_parser.h" |
| 43 #include "net/websockets/websocket_handshake_challenge.h" | 44 #include "net/websockets/websocket_handshake_challenge.h" |
| 44 #include "net/websockets/websocket_handshake_constants.h" | 45 #include "net/websockets/websocket_handshake_constants.h" |
| 45 #include "net/websockets/websocket_handshake_request_info.h" | 46 #include "net/websockets/websocket_handshake_request_info.h" |
| 46 #include "net/websockets/websocket_handshake_response_info.h" | 47 #include "net/websockets/websocket_handshake_response_info.h" |
| 47 #include "net/websockets/websocket_stream.h" | 48 #include "net/websockets/websocket_stream.h" |
| 48 | 49 |
| 49 namespace net { | 50 namespace net { |
| 50 | 51 |
| 51 namespace { | 52 namespace { |
| 52 | 53 |
| 53 const char kConnectionErrorStatusLine[] = "HTTP/1.1 503 Connection Error"; | 54 const char kConnectionErrorStatusLine[] = "HTTP/1.1 503 Connection Error"; |
| 54 | 55 |
| 55 } // namespace | 56 } // namespace |
| 56 | 57 |
| 57 // TODO(ricea): If more extensions are added, replace this with a more general | 58 // TODO(ricea): If more extensions are added, replace this with a more general |
| 58 // mechanism. | 59 // mechanism. |
| 59 struct WebSocketExtensionParams { | 60 struct WebSocketExtensionParams { |
| 60 WebSocketExtensionParams() | 61 bool deflate_enabled = false; |
| 61 : deflate_enabled(false), | 62 WebSocketDeflateParameters deflate_parameters; |
| 62 client_window_bits(15), | |
| 63 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} | |
| 64 | |
| 65 bool deflate_enabled; | |
| 66 int client_window_bits; | |
| 67 WebSocketDeflater::ContextTakeOverMode deflate_mode; | |
| 68 }; | 63 }; |
| 69 | 64 |
| 70 namespace { | 65 namespace { |
| 71 | 66 |
| 72 enum GetHeaderResult { | 67 enum GetHeaderResult { |
| 73 GET_HEADER_OK, | 68 GET_HEADER_OK, |
| 74 GET_HEADER_MISSING, | 69 GET_HEADER_MISSING, |
| 75 GET_HEADER_MULTIPLE, | 70 GET_HEADER_MULTIPLE, |
| 76 }; | 71 }; |
| 77 | 72 |
| (...skipping 150 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 228 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 223 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
| 229 *failure_message = | 224 *failure_message = |
| 230 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 225 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
| 231 "but no response was received"; | 226 "but no response was received"; |
| 232 return false; | 227 return false; |
| 233 } | 228 } |
| 234 *sub_protocol = value; | 229 *sub_protocol = value; |
| 235 return true; | 230 return true; |
| 236 } | 231 } |
| 237 | 232 |
| 238 bool DeflateError(std::string* message, const base::StringPiece& piece) { | |
| 239 *message = "Error in permessage-deflate: "; | |
| 240 piece.AppendToString(message); | |
| 241 return false; | |
| 242 } | |
| 243 | |
| 244 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, | |
| 245 std::string* failure_message, | |
| 246 WebSocketExtensionParams* params) { | |
| 247 static const char kClientPrefix[] = "client_"; | |
| 248 static const char kServerPrefix[] = "server_"; | |
| 249 static const char kNoContextTakeover[] = "no_context_takeover"; | |
| 250 static const char kMaxWindowBits[] = "max_window_bits"; | |
| 251 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | |
| 252 static_assert(kPrefixLen == arraysize(kServerPrefix) - 1, | |
| 253 "the strings server and client must be the same length"); | |
| 254 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | |
| 255 | |
| 256 DCHECK_EQ("permessage-deflate", extension.name()); | |
| 257 const ParameterVector& parameters = extension.parameters(); | |
| 258 std::set<std::string> seen_names; | |
| 259 for (ParameterVector::const_iterator it = parameters.begin(); | |
| 260 it != parameters.end(); ++it) { | |
| 261 const std::string& name = it->name(); | |
| 262 if (seen_names.count(name) != 0) { | |
| 263 return DeflateError( | |
| 264 failure_message, | |
| 265 "Received duplicate permessage-deflate extension parameter " + name); | |
| 266 } | |
| 267 seen_names.insert(name); | |
| 268 const std::string client_or_server(name, 0, kPrefixLen); | |
| 269 const bool is_client = (client_or_server == kClientPrefix); | |
| 270 if (!is_client && client_or_server != kServerPrefix) { | |
| 271 return DeflateError( | |
| 272 failure_message, | |
| 273 "Received an unexpected permessage-deflate extension parameter"); | |
| 274 } | |
| 275 const std::string rest(name, kPrefixLen); | |
| 276 if (rest == kNoContextTakeover) { | |
| 277 if (it->HasValue()) { | |
| 278 return DeflateError(failure_message, | |
| 279 "Received invalid " + name + " parameter"); | |
| 280 } | |
| 281 if (is_client) | |
| 282 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; | |
| 283 } else if (rest == kMaxWindowBits) { | |
| 284 if (!it->HasValue()) | |
| 285 return DeflateError(failure_message, name + " must have value"); | |
| 286 int bits = 0; | |
| 287 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || | |
| 288 it->value()[0] == '0' || | |
| 289 it->value().find_first_not_of("0123456789") != std::string::npos) { | |
| 290 return DeflateError(failure_message, | |
| 291 "Received invalid " + name + " parameter"); | |
| 292 } | |
| 293 if (is_client) | |
| 294 params->client_window_bits = bits; | |
| 295 } else { | |
| 296 return DeflateError( | |
| 297 failure_message, | |
| 298 "Received an unexpected permessage-deflate extension parameter"); | |
| 299 } | |
| 300 } | |
| 301 params->deflate_enabled = true; | |
| 302 return true; | |
| 303 } | |
| 304 | |
| 305 bool ValidateExtensions(const HttpResponseHeaders* headers, | 233 bool ValidateExtensions(const HttpResponseHeaders* headers, |
| 306 std::string* accepted_extensions_descriptor, | 234 std::string* accepted_extensions_descriptor, |
| 307 std::string* failure_message, | 235 std::string* failure_message, |
| 308 WebSocketExtensionParams* params) { | 236 WebSocketExtensionParams* params) { |
| 309 void* state = nullptr; | 237 void* state = nullptr; |
| 310 std::string header_value; | 238 std::string header_value; |
| 311 std::vector<std::string> header_values; | 239 std::vector<std::string> header_values; |
| 312 // TODO(ricea): If adding support for additional extensions, generalise this | 240 // TODO(ricea): If adding support for additional extensions, generalise this |
| 313 // code. | 241 // code. |
| 314 bool seen_permessage_deflate = false; | 242 bool seen_permessage_deflate = false; |
| (...skipping 10 matching lines...) Expand all Loading... |
| 325 } | 253 } |
| 326 | 254 |
| 327 const std::vector<WebSocketExtension>& extensions = parser.extensions(); | 255 const std::vector<WebSocketExtension>& extensions = parser.extensions(); |
| 328 for (const auto& extension : extensions) { | 256 for (const auto& extension : extensions) { |
| 329 if (extension.name() == "permessage-deflate") { | 257 if (extension.name() == "permessage-deflate") { |
| 330 if (seen_permessage_deflate) { | 258 if (seen_permessage_deflate) { |
| 331 *failure_message = "Received duplicate permessage-deflate response"; | 259 *failure_message = "Received duplicate permessage-deflate response"; |
| 332 return false; | 260 return false; |
| 333 } | 261 } |
| 334 seen_permessage_deflate = true; | 262 seen_permessage_deflate = true; |
| 335 | 263 auto& deflate_parameters = params->deflate_parameters; |
| 336 if (!ValidatePerMessageDeflateExtension(extension, failure_message, | 264 if (!deflate_parameters.Initialize(extension, failure_message) || |
| 337 params)) { | 265 !deflate_parameters.IsValidAsResponse(failure_message)) { |
| 266 *failure_message = "Error in permessage-deflate: " + *failure_message; |
| 338 return false; | 267 return false; |
| 339 } | 268 } |
| 269 // Note that we don't have to check the request-response compatibility |
| 270 // here because we send a request compatible with any valid responses. |
| 271 // TODO(yhirano): Place a DCHECK here. |
| 272 |
| 340 header_values.push_back(header_value); | 273 header_values.push_back(header_value); |
| 341 } else { | 274 } else { |
| 342 *failure_message = "Found an unsupported extension '" + | 275 *failure_message = "Found an unsupported extension '" + |
| 343 extension.name() + | 276 extension.name() + |
| 344 "' in 'Sec-WebSocket-Extensions' header"; | 277 "' in 'Sec-WebSocket-Extensions' header"; |
| 345 return false; | 278 return false; |
| 346 } | 279 } |
| 347 } | 280 } |
| 348 } | 281 } |
| 349 *accepted_extensions_descriptor = base::JoinString(header_values, ", "); | 282 *accepted_extensions_descriptor = base::JoinString(header_values, ", "); |
| 283 params->deflate_enabled = seen_permessage_deflate; |
| 350 return true; | 284 return true; |
| 351 } | 285 } |
| 352 | 286 |
| 353 } // namespace | 287 } // namespace |
| 354 | 288 |
| 355 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( | 289 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( |
| 356 scoped_ptr<ClientSocketHandle> connection, | 290 scoped_ptr<ClientSocketHandle> connection, |
| 357 WebSocketStream::ConnectDelegate* connect_delegate, | 291 WebSocketStream::ConnectDelegate* connect_delegate, |
| 358 bool using_proxy, | 292 bool using_proxy, |
| 359 std::vector<std::string> requested_sub_protocols, | 293 std::vector<std::string> requested_sub_protocols, |
| (...skipping 164 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 524 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection()); | 458 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection()); |
| 525 scoped_ptr<WebSocketStream> basic_stream( | 459 scoped_ptr<WebSocketStream> basic_stream( |
| 526 new WebSocketBasicStream(state_.ReleaseConnection(), | 460 new WebSocketBasicStream(state_.ReleaseConnection(), |
| 527 state_.read_buf(), | 461 state_.read_buf(), |
| 528 sub_protocol_, | 462 sub_protocol_, |
| 529 extensions_)); | 463 extensions_)); |
| 530 DCHECK(extension_params_.get()); | 464 DCHECK(extension_params_.get()); |
| 531 if (extension_params_->deflate_enabled) { | 465 if (extension_params_->deflate_enabled) { |
| 532 UMA_HISTOGRAM_ENUMERATION( | 466 UMA_HISTOGRAM_ENUMERATION( |
| 533 "Net.WebSocket.DeflateMode", | 467 "Net.WebSocket.DeflateMode", |
| 534 extension_params_->deflate_mode, | 468 extension_params_->deflate_parameters.client_context_take_over_mode(), |
| 535 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); | 469 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); |
| 536 | 470 |
| 537 return scoped_ptr<WebSocketStream>( | 471 return scoped_ptr<WebSocketStream>(new WebSocketDeflateStream( |
| 538 new WebSocketDeflateStream(basic_stream.Pass(), | 472 basic_stream.Pass(), extension_params_->deflate_parameters, |
| 539 extension_params_->deflate_mode, | 473 scoped_ptr<WebSocketDeflatePredictor>( |
| 540 extension_params_->client_window_bits, | 474 new WebSocketDeflatePredictorImpl))); |
| 541 scoped_ptr<WebSocketDeflatePredictor>( | |
| 542 new WebSocketDeflatePredictorImpl))); | |
| 543 } else { | 475 } else { |
| 544 return basic_stream.Pass(); | 476 return basic_stream.Pass(); |
| 545 } | 477 } |
| 546 } | 478 } |
| 547 | 479 |
| 548 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( | 480 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
| 549 const std::string& key) { | 481 const std::string& key) { |
| 550 handshake_challenge_for_testing_.reset(new std::string(key)); | 482 handshake_challenge_for_testing_.reset(new std::string(key)); |
| 551 } | 483 } |
| 552 | 484 |
| (...skipping 92 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 645 set_failure_message("Error during WebSocket handshake: " + failure_message); | 577 set_failure_message("Error during WebSocket handshake: " + failure_message); |
| 646 return ERR_INVALID_RESPONSE; | 578 return ERR_INVALID_RESPONSE; |
| 647 } | 579 } |
| 648 | 580 |
| 649 void WebSocketBasicHandshakeStream::set_failure_message( | 581 void WebSocketBasicHandshakeStream::set_failure_message( |
| 650 const std::string& failure_message) { | 582 const std::string& failure_message) { |
| 651 *failure_message_ = failure_message; | 583 *failure_message_ = failure_message; |
| 652 } | 584 } |
| 653 | 585 |
| 654 } // namespace net | 586 } // namespace net |
| OLD | NEW |