| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
| 6 | 6 |
| 7 #include <algorithm> | 7 #include <algorithm> |
| 8 #include <iterator> | 8 #include <iterator> |
| 9 #include <set> | 9 #include <set> |
| 10 #include <string> | 10 #include <string> |
| (...skipping 51 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 62 GET_HEADER_OK, | 62 GET_HEADER_OK, |
| 63 GET_HEADER_MISSING, | 63 GET_HEADER_MISSING, |
| 64 GET_HEADER_MULTIPLE, | 64 GET_HEADER_MULTIPLE, |
| 65 }; | 65 }; |
| 66 | 66 |
| 67 std::string MissingHeaderMessage(const std::string& header_name) { | 67 std::string MissingHeaderMessage(const std::string& header_name) { |
| 68 return std::string("'") + header_name + "' header is missing"; | 68 return std::string("'") + header_name + "' header is missing"; |
| 69 } | 69 } |
| 70 | 70 |
| 71 std::string MultipleHeaderValuesMessage(const std::string& header_name) { | 71 std::string MultipleHeaderValuesMessage(const std::string& header_name) { |
| 72 return | 72 return std::string("'") + header_name + |
| 73 std::string("'") + | 73 "' header must not appear more than once in a response"; |
| 74 header_name + | |
| 75 "' header must not appear more than once in a response"; | |
| 76 } | 74 } |
| 77 | 75 |
| 78 std::string GenerateHandshakeChallenge() { | 76 std::string GenerateHandshakeChallenge() { |
| 79 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); | 77 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); |
| 80 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); | 78 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); |
| 81 std::string encoded_challenge; | 79 std::string encoded_challenge; |
| 82 base::Base64Encode(raw_challenge, &encoded_challenge); | 80 base::Base64Encode(raw_challenge, &encoded_challenge); |
| 83 return encoded_challenge; | 81 return encoded_challenge; |
| 84 } | 82 } |
| 85 | 83 |
| (...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 118 } | 116 } |
| 119 DCHECK_EQ(result, GET_HEADER_OK); | 117 DCHECK_EQ(result, GET_HEADER_OK); |
| 120 return true; | 118 return true; |
| 121 } | 119 } |
| 122 | 120 |
| 123 bool ValidateUpgrade(const HttpResponseHeaders* headers, | 121 bool ValidateUpgrade(const HttpResponseHeaders* headers, |
| 124 std::string* failure_message) { | 122 std::string* failure_message) { |
| 125 std::string value; | 123 std::string value; |
| 126 GetHeaderResult result = | 124 GetHeaderResult result = |
| 127 GetSingleHeaderValue(headers, websockets::kUpgrade, &value); | 125 GetSingleHeaderValue(headers, websockets::kUpgrade, &value); |
| 128 if (!ValidateHeaderHasSingleValue(result, | 126 if (!ValidateHeaderHasSingleValue( |
| 129 websockets::kUpgrade, | 127 result, websockets::kUpgrade, failure_message)) { |
| 130 failure_message)) { | |
| 131 return false; | 128 return false; |
| 132 } | 129 } |
| 133 | 130 |
| 134 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { | 131 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { |
| 135 *failure_message = | 132 *failure_message = "'Upgrade' header value is not 'WebSocket': " + value; |
| 136 "'Upgrade' header value is not 'WebSocket': " + value; | |
| 137 return false; | 133 return false; |
| 138 } | 134 } |
| 139 return true; | 135 return true; |
| 140 } | 136 } |
| 141 | 137 |
| 142 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, | 138 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, |
| 143 const std::string& expected, | 139 const std::string& expected, |
| 144 std::string* failure_message) { | 140 std::string* failure_message) { |
| 145 std::string actual; | 141 std::string actual; |
| 146 GetHeaderResult result = | 142 GetHeaderResult result = |
| 147 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); | 143 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); |
| 148 if (!ValidateHeaderHasSingleValue(result, | 144 if (!ValidateHeaderHasSingleValue( |
| 149 websockets::kSecWebSocketAccept, | 145 result, websockets::kSecWebSocketAccept, failure_message)) { |
| 150 failure_message)) { | |
| 151 return false; | 146 return false; |
| 152 } | 147 } |
| 153 | 148 |
| 154 if (expected != actual) { | 149 if (expected != actual) { |
| 155 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; | 150 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; |
| 156 return false; | 151 return false; |
| 157 } | 152 } |
| 158 return true; | 153 return true; |
| 159 } | 154 } |
| 160 | 155 |
| (...skipping 35 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 196 has_invalid_protocol = true; | 191 has_invalid_protocol = true; |
| 197 if (++count > 1) | 192 if (++count > 1) |
| 198 has_multiple_protocols = true; | 193 has_multiple_protocols = true; |
| 199 } | 194 } |
| 200 | 195 |
| 201 if (has_multiple_protocols) { | 196 if (has_multiple_protocols) { |
| 202 *failure_message = | 197 *failure_message = |
| 203 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); | 198 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); |
| 204 return false; | 199 return false; |
| 205 } else if (count > 0 && requested_sub_protocols.size() == 0) { | 200 } else if (count > 0 && requested_sub_protocols.size() == 0) { |
| 206 *failure_message = | 201 *failure_message = std::string( |
| 207 std::string("Response must not include 'Sec-WebSocket-Protocol' " | 202 "Response must not include 'Sec-WebSocket-Protocol' " |
| 208 "header if not present in request: ") | 203 "header if not present in request: ") + |
| 209 + value; | 204 value; |
| 210 return false; | 205 return false; |
| 211 } else if (has_invalid_protocol) { | 206 } else if (has_invalid_protocol) { |
| 212 *failure_message = | 207 *failure_message = "'Sec-WebSocket-Protocol' header value '" + value + |
| 213 "'Sec-WebSocket-Protocol' header value '" + | 208 "' in response does not match any of sent values"; |
| 214 value + | |
| 215 "' in response does not match any of sent values"; | |
| 216 return false; | 209 return false; |
| 217 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 210 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
| 218 *failure_message = | 211 *failure_message = |
| 219 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 212 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
| 220 "but no response was received"; | 213 "but no response was received"; |
| 221 return false; | 214 return false; |
| 222 } | 215 } |
| 223 *sub_protocol = value; | 216 *sub_protocol = value; |
| 224 return true; | 217 return true; |
| 225 } | 218 } |
| (...skipping 13 matching lines...) Expand all Loading... |
| 239 static const char kMaxWindowBits[] = "max_window_bits"; | 232 static const char kMaxWindowBits[] = "max_window_bits"; |
| 240 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | 233 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
| 241 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, | 234 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, |
| 242 the_strings_server_and_client_must_be_the_same_length); | 235 the_strings_server_and_client_must_be_the_same_length); |
| 243 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | 236 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
| 244 | 237 |
| 245 DCHECK_EQ("permessage-deflate", extension.name()); | 238 DCHECK_EQ("permessage-deflate", extension.name()); |
| 246 const ParameterVector& parameters = extension.parameters(); | 239 const ParameterVector& parameters = extension.parameters(); |
| 247 std::set<std::string> seen_names; | 240 std::set<std::string> seen_names; |
| 248 for (ParameterVector::const_iterator it = parameters.begin(); | 241 for (ParameterVector::const_iterator it = parameters.begin(); |
| 249 it != parameters.end(); ++it) { | 242 it != parameters.end(); |
| 243 ++it) { |
| 250 const std::string& name = it->name(); | 244 const std::string& name = it->name(); |
| 251 if (seen_names.count(name) != 0) { | 245 if (seen_names.count(name) != 0) { |
| 252 return DeflateError( | 246 return DeflateError( |
| 253 failure_message, | 247 failure_message, |
| 254 "Received duplicate permessage-deflate extension parameter " + name); | 248 "Received duplicate permessage-deflate extension parameter " + name); |
| 255 } | 249 } |
| 256 seen_names.insert(name); | 250 seen_names.insert(name); |
| 257 const std::string client_or_server(name, 0, kPrefixLen); | 251 const std::string client_or_server(name, 0, kPrefixLen); |
| 258 const bool is_client = (client_or_server == kClientPrefix); | 252 const bool is_client = (client_or_server == kClientPrefix); |
| 259 if (!is_client && client_or_server != kServerPrefix) { | 253 if (!is_client && client_or_server != kServerPrefix) { |
| (...skipping 36 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 296 std::string* extensions, | 290 std::string* extensions, |
| 297 std::string* failure_message, | 291 std::string* failure_message, |
| 298 WebSocketExtensionParams* params) { | 292 WebSocketExtensionParams* params) { |
| 299 void* state = NULL; | 293 void* state = NULL; |
| 300 std::string value; | 294 std::string value; |
| 301 std::vector<std::string> accepted_extensions; | 295 std::vector<std::string> accepted_extensions; |
| 302 // TODO(ricea): If adding support for additional extensions, generalise this | 296 // TODO(ricea): If adding support for additional extensions, generalise this |
| 303 // code. | 297 // code. |
| 304 bool seen_permessage_deflate = false; | 298 bool seen_permessage_deflate = false; |
| 305 while (headers->EnumerateHeader( | 299 while (headers->EnumerateHeader( |
| 306 &state, websockets::kSecWebSocketExtensions, &value)) { | 300 &state, websockets::kSecWebSocketExtensions, &value)) { |
| 307 WebSocketExtensionParser parser; | 301 WebSocketExtensionParser parser; |
| 308 parser.Parse(value); | 302 parser.Parse(value); |
| 309 if (parser.has_error()) { | 303 if (parser.has_error()) { |
| 310 // TODO(yhirano) Set appropriate failure message. | 304 // TODO(yhirano) Set appropriate failure message. |
| 311 *failure_message = | 305 *failure_message = |
| 312 "'Sec-WebSocket-Extensions' header value is " | 306 "'Sec-WebSocket-Extensions' header value is " |
| 313 "rejected by the parser: " + | 307 "rejected by the parser: " + |
| 314 value; | 308 value; |
| 315 return false; | 309 return false; |
| 316 } | 310 } |
| 317 if (parser.extension().name() == "permessage-deflate") { | 311 if (parser.extension().name() == "permessage-deflate") { |
| 318 if (seen_permessage_deflate) { | 312 if (seen_permessage_deflate) { |
| 319 *failure_message = "Received duplicate permessage-deflate response"; | 313 *failure_message = "Received duplicate permessage-deflate response"; |
| 320 return false; | 314 return false; |
| 321 } | 315 } |
| 322 seen_permessage_deflate = true; | 316 seen_permessage_deflate = true; |
| 323 if (!ValidatePerMessageDeflateExtension( | 317 if (!ValidatePerMessageDeflateExtension( |
| 324 parser.extension(), failure_message, params)) | 318 parser.extension(), failure_message, params)) |
| 325 return false; | 319 return false; |
| 326 } else { | 320 } else { |
| 327 *failure_message = | 321 *failure_message = "Found an unsupported extension '" + |
| 328 "Found an unsupported extension '" + | 322 parser.extension().name() + |
| 329 parser.extension().name() + | 323 "' in 'Sec-WebSocket-Extensions' header"; |
| 330 "' in 'Sec-WebSocket-Extensions' header"; | |
| 331 return false; | 324 return false; |
| 332 } | 325 } |
| 333 accepted_extensions.push_back(value); | 326 accepted_extensions.push_back(value); |
| 334 } | 327 } |
| 335 *extensions = JoinString(accepted_extensions, ", "); | 328 *extensions = JoinString(accepted_extensions, ", "); |
| 336 return true; | 329 return true; |
| 337 } | 330 } |
| 338 | 331 |
| 339 } // namespace | 332 } // namespace |
| 340 | 333 |
| 341 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( | 334 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( |
| 342 scoped_ptr<ClientSocketHandle> connection, | 335 scoped_ptr<ClientSocketHandle> connection, |
| 343 WebSocketStream::ConnectDelegate* connect_delegate, | 336 WebSocketStream::ConnectDelegate* connect_delegate, |
| 344 bool using_proxy, | 337 bool using_proxy, |
| 345 std::vector<std::string> requested_sub_protocols, | 338 std::vector<std::string> requested_sub_protocols, |
| 346 std::vector<std::string> requested_extensions) | 339 std::vector<std::string> requested_extensions) |
| 347 : state_(connection.release(), using_proxy), | 340 : state_(connection.release(), using_proxy), |
| 348 connect_delegate_(connect_delegate), | 341 connect_delegate_(connect_delegate), |
| 349 http_response_info_(NULL), | 342 http_response_info_(NULL), |
| 350 requested_sub_protocols_(requested_sub_protocols), | 343 requested_sub_protocols_(requested_sub_protocols), |
| 351 requested_extensions_(requested_extensions) {} | 344 requested_extensions_(requested_extensions) { |
| 345 } |
| 352 | 346 |
| 353 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {} | 347 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() { |
| 348 } |
| 354 | 349 |
| 355 int WebSocketBasicHandshakeStream::InitializeStream( | 350 int WebSocketBasicHandshakeStream::InitializeStream( |
| 356 const HttpRequestInfo* request_info, | 351 const HttpRequestInfo* request_info, |
| 357 RequestPriority priority, | 352 RequestPriority priority, |
| 358 const BoundNetLog& net_log, | 353 const BoundNetLog& net_log, |
| 359 const CompletionCallback& callback) { | 354 const CompletionCallback& callback) { |
| 360 url_ = request_info->url; | 355 url_ = request_info->url; |
| 361 state_.Initialize(request_info, priority, net_log, callback); | 356 state_.Initialize(request_info, priority, net_log, callback); |
| 362 return OK; | 357 return OK; |
| 363 } | 358 } |
| (...skipping 111 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 475 | 470 |
| 476 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { | 471 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { |
| 477 parser()->GetSSLInfo(ssl_info); | 472 parser()->GetSSLInfo(ssl_info); |
| 478 } | 473 } |
| 479 | 474 |
| 480 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( | 475 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( |
| 481 SSLCertRequestInfo* cert_request_info) { | 476 SSLCertRequestInfo* cert_request_info) { |
| 482 parser()->GetSSLCertRequestInfo(cert_request_info); | 477 parser()->GetSSLCertRequestInfo(cert_request_info); |
| 483 } | 478 } |
| 484 | 479 |
| 485 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; } | 480 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { |
| 481 return false; |
| 482 } |
| 486 | 483 |
| 487 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { | 484 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { |
| 488 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); | 485 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); |
| 489 drainer->Start(session); | 486 drainer->Start(session); |
| 490 // |drainer| will delete itself. | 487 // |drainer| will delete itself. |
| 491 } | 488 } |
| 492 | 489 |
| 493 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { | 490 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { |
| 494 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is | 491 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is |
| 495 // gone, then copy whatever has happened there over here. | 492 // gone, then copy whatever has happened there over here. |
| (...skipping 100 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 596 std::string("Error during WebSocket handshake: ") + ErrorToString(rv); | 593 std::string("Error during WebSocket handshake: ") + ErrorToString(rv); |
| 597 OnFinishOpeningHandshake(); | 594 OnFinishOpeningHandshake(); |
| 598 return rv; | 595 return rv; |
| 599 } | 596 } |
| 600 } | 597 } |
| 601 | 598 |
| 602 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( | 599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| 603 const HttpResponseHeaders* headers) { | 600 const HttpResponseHeaders* headers) { |
| 604 extension_params_.reset(new WebSocketExtensionParams); | 601 extension_params_.reset(new WebSocketExtensionParams); |
| 605 if (ValidateUpgrade(headers, &failure_message_) && | 602 if (ValidateUpgrade(headers, &failure_message_) && |
| 606 ValidateSecWebSocketAccept(headers, | 603 ValidateSecWebSocketAccept( |
| 607 handshake_challenge_response_, | 604 headers, handshake_challenge_response_, &failure_message_) && |
| 608 &failure_message_) && | |
| 609 ValidateConnection(headers, &failure_message_) && | 605 ValidateConnection(headers, &failure_message_) && |
| 610 ValidateSubProtocol(headers, | 606 ValidateSubProtocol(headers, |
| 611 requested_sub_protocols_, | 607 requested_sub_protocols_, |
| 612 &sub_protocol_, | 608 &sub_protocol_, |
| 613 &failure_message_) && | 609 &failure_message_) && |
| 614 ValidateExtensions(headers, | 610 ValidateExtensions(headers, |
| 615 requested_extensions_, | 611 requested_extensions_, |
| 616 &extensions_, | 612 &extensions_, |
| 617 &failure_message_, | 613 &failure_message_, |
| 618 extension_params_.get())) { | 614 extension_params_.get())) { |
| 619 return OK; | 615 return OK; |
| 620 } | 616 } |
| 621 failure_message_ = "Error during WebSocket handshake: " + failure_message_; | 617 failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
| 622 return ERR_INVALID_RESPONSE; | 618 return ERR_INVALID_RESPONSE; |
| 623 } | 619 } |
| 624 | 620 |
| 625 } // namespace net | 621 } // namespace net |
| OLD | NEW |