Chromium Code Reviews| Index: net/websockets/websocket_basic_handshake_stream.cc |
| diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc |
| index 73a10453be0520d33aa04a56db34246abf654c94..aeae2baf2171c07bc53838e0ef41296120359462 100644 |
| --- a/net/websockets/websocket_basic_handshake_stream.cc |
| +++ b/net/websockets/websocket_basic_handshake_stream.cc |
| @@ -6,6 +6,7 @@ |
| #include <algorithm> |
| #include <iterator> |
| +#include <string> |
| #include "base/base64.h" |
| #include "base/basictypes.h" |
| @@ -13,6 +14,7 @@ |
| #include "base/containers/hash_tables.h" |
| #include "base/stl_util.h" |
| #include "base/strings/string_util.h" |
| +#include "base/strings/stringprintf.h" |
| #include "crypto/random.h" |
| #include "net/http/http_request_headers.h" |
| #include "net/http/http_request_info.h" |
| @@ -22,6 +24,7 @@ |
| #include "net/http/http_stream_parser.h" |
| #include "net/socket/client_socket_handle.h" |
| #include "net/websockets/websocket_basic_stream.h" |
| +#include "net/websockets/websocket_extension_parser.h" |
| #include "net/websockets/websocket_handshake_constants.h" |
| #include "net/websockets/websocket_handshake_handler.h" |
| #include "net/websockets/websocket_stream.h" |
| @@ -29,6 +32,23 @@ |
| namespace net { |
| namespace { |
| +enum GetHeaderResult { |
| + GET_HEADER_OK, |
| + GET_HEADER_MISSING, |
| + GET_HEADER_MULTIPLE, |
| +}; |
| + |
| +std::string MissingHeaderMessage(const std::string& header_name) { |
| + return std::string("'") + header_name + "' header is missing"; |
| +} |
| + |
| +std::string MultipleHeaderValuesMessage(const std::string& header_name) { |
| + return |
| + std::string("'") + |
| + header_name + |
| + "' header must not appear more than once in a response"; |
| +} |
| + |
| std::string GenerateHandshakeChallenge() { |
| std::string raw_challenge(websockets::kRawChallengeLength, '\0'); |
| crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); |
| @@ -45,57 +65,162 @@ void AddVectorHeaderIfNonEmpty(const char* name, |
| headers->SetHeader(name, JoinString(value, ", ")); |
| } |
| -// If |case_sensitive| is false, then |value| must be in lower-case. |
| -bool ValidateSingleTokenHeader( |
| - const scoped_refptr<HttpResponseHeaders>& headers, |
| - const base::StringPiece& name, |
| - const std::string& value, |
| - bool case_sensitive) { |
| +GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers, |
| + const base::StringPiece& name, |
| + std::string* value) { |
| void* state = NULL; |
| - std::string token; |
| - int tokens = 0; |
| - bool has_value = false; |
| - while (headers->EnumerateHeader(&state, name, &token)) { |
| - if (++tokens > 1) |
| - return false; |
| - has_value = case_sensitive ? value == token |
| - : LowerCaseEqualsASCII(token, value.c_str()); |
| + size_t num_values = 0; |
| + std::string temp_value; |
| + while (headers->EnumerateHeader(&state, name, &temp_value)) { |
| + if (++num_values > 1) |
| + return GET_HEADER_MULTIPLE; |
| + *value = temp_value; |
| + } |
| + return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING; |
| +} |
| + |
| +bool ValidateHeaderHasSingleValue(GetHeaderResult result, |
| + const std::string& header_name, |
| + std::string* failure_message) { |
| + if (result == GET_HEADER_MISSING) { |
| + *failure_message = MissingHeaderMessage(header_name); |
| + return false; |
| + } |
| + if (result == GET_HEADER_MULTIPLE) { |
| + *failure_message = MultipleHeaderValuesMessage(header_name); |
| + return false; |
| + } |
| + DCHECK_EQ(result, GET_HEADER_OK); |
| + return true; |
| +} |
| + |
| +bool ValidateUpgrade(const HttpResponseHeaders* headers, |
| + std::string* failure_message) { |
| + std::string value; |
| + GetHeaderResult result = |
| + GetSingleHeaderValue(headers, websockets::kUpgrade, &value); |
| + if (!ValidateHeaderHasSingleValue(result, |
| + websockets::kUpgrade, |
| + failure_message)) { |
| + return false; |
| + } |
| + |
| + if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { |
| + *failure_message = |
| + "'Upgrade' header value is not 'WebSocket': " + value; |
| + return false; |
| + } |
| + return true; |
| +} |
| + |
| +bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, |
| + const std::string& expected, |
| + std::string* failure_message) { |
| + std::string actual; |
| + GetHeaderResult result = |
| + GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); |
| + if (!ValidateHeaderHasSingleValue(result, |
| + websockets::kSecWebSocketAccept, |
| + failure_message)) { |
| + return false; |
| + } |
| + |
| + if (expected != actual) { |
| + *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; |
| + return false; |
| + } |
| + return true; |
| +} |
| + |
| +bool ValidateConnection(const HttpResponseHeaders* headers, |
| + std::string* failure_message) { |
| + // Connection header is permitted to contain other tokens. |
| + if (!headers->HasHeader(HttpRequestHeaders::kConnection)) { |
| + *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection); |
| + return false; |
| } |
| - return has_value; |
| + if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection, |
| + websockets::kUpgrade)) { |
| + *failure_message = "'Connection' header value must contain 'Upgrade'"; |
| + return false; |
| + } |
| + return true; |
| } |
| bool ValidateSubProtocol( |
| - const scoped_refptr<HttpResponseHeaders>& headers, |
| + const HttpResponseHeaders* headers, |
| const std::vector<std::string>& requested_sub_protocols, |
| - std::string* sub_protocol) { |
| + std::string* sub_protocol, |
| + std::string* failure_message) { |
| void* state = NULL; |
| - std::string token; |
| + std::string last_parsed; |
|
tyoshino (SeeGerritForStatus)
2014/01/09 05:53:55
i suggested last_parsed for the temporary variable
yhirano
2014/01/09 06:04:19
Done.
|
| base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), |
| requested_sub_protocols.end()); |
| - int accepted = 0; |
| - while (headers->EnumerateHeader( |
| - &state, websockets::kSecWebSocketProtocol, &token)) { |
| - if (requested_set.count(token) == 0) |
| - return false; |
| + int count = 0; |
| + bool has_multiple_protocols = false; |
| + bool has_invalid_protocol = false; |
| + |
| + while (!has_invalid_protocol || !has_multiple_protocols) { |
| + std::string temp_value; |
| + if (!headers->EnumerateHeader( |
| + &state, websockets::kSecWebSocketProtocol, &temp_value)) |
| + break; |
| + last_parsed = temp_value; |
| + if (requested_set.count(last_parsed) == 0) |
| + has_invalid_protocol = true; |
| + if (++count > 1) |
| + has_multiple_protocols = true; |
| + } |
| - *sub_protocol = token; |
| - // The server is only allowed to accept one protocol. |
| - if (++accepted > 1) |
| - return false; |
| + if (has_multiple_protocols) { |
| + *failure_message = |
| + MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); |
| + return false; |
| + } else if (count > 0 && requested_sub_protocols.size() == 0) { |
| + *failure_message = |
| + std::string("Response must not include 'Sec-WebSocket-Protocol' " |
| + "header if not present in request: ") |
| + + last_parsed; |
| + return false; |
| + } else if (has_invalid_protocol) { |
| + *failure_message = |
| + "'Sec-WebSocket-Protocol' header value '" + |
| + last_parsed + |
| + "' in response does not match any of sent values"; |
| + return false; |
| + } else if (requested_sub_protocols.size() > 0 && count == 0) { |
| + *failure_message = |
| + "Sent non-empty 'Sec-WebSocket-Protocol' header " |
| + "but no response was received"; |
| + return false; |
| } |
| - // If the browser requested > 0 protocols, the server is required to accept |
| - // one. |
| - return requested_set.empty() || accepted == 1; |
| + *sub_protocol = last_parsed; |
| + return true; |
| } |
| -bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers, |
| +bool ValidateExtensions(const HttpResponseHeaders* headers, |
| const std::vector<std::string>& requested_extensions, |
| - std::string* extensions) { |
| + std::string* extensions, |
| + std::string* failure_message) { |
| void* state = NULL; |
| - std::string token; |
| + std::string value; |
| while (headers->EnumerateHeader( |
| - &state, websockets::kSecWebSocketExtensions, &token)) { |
| + &state, websockets::kSecWebSocketExtensions, &value)) { |
| + WebSocketExtensionParser parser; |
| + parser.Parse(value); |
| + if (parser.has_error()) { |
| + // TODO(yhirano) Set appropriate failure message. |
| + *failure_message = |
| + "'Sec-WebSocket-Extensions' header value is " |
| + "rejected by the parser: " + |
| + value; |
| + return false; |
| + } |
| // TODO(ricea): Accept permessage-deflate with valid parameters. |
| + *failure_message = |
| + "Found an unsupported extension '" + |
| + parser.extension().name() + |
| + "' in 'Sec-WebSocket-Extensions' header"; |
| return false; |
| } |
| return true; |
| @@ -267,6 +392,10 @@ void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
| handshake_challenge_for_testing_.reset(new std::string(key)); |
| } |
| +std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { |
| + return failure_message_; |
| +} |
| + |
| void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( |
| const CompletionCallback& callback, |
| int result) { |
| @@ -292,26 +421,30 @@ int WebSocketBasicHandshakeStream::ValidateResponse() { |
| // Other status codes are potentially risky (see the warnings in the |
| // WHATWG WebSocket API spec) and so are dropped by default. |
| default: |
| + failure_message_ = base::StringPrintf("Unexpected status code: %d", |
| + headers->response_code()); |
| return ERR_INVALID_RESPONSE; |
| } |
| } |
| int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| const scoped_refptr<HttpResponseHeaders>& headers) { |
| - if (ValidateSingleTokenHeader(headers, |
| - websockets::kUpgrade, |
| - websockets::kWebSocketLowercase, |
| - false) && |
| - ValidateSingleTokenHeader(headers, |
| - websockets::kSecWebSocketAccept, |
| - handshake_challenge_response_, |
| - true) && |
| - headers->HasHeaderValue(HttpRequestHeaders::kConnection, |
| - websockets::kUpgrade) && |
| - ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) && |
| - ValidateExtensions(headers, requested_extensions_, &extensions_)) { |
| + if (ValidateUpgrade(headers.get(), &failure_message_) && |
| + ValidateSecWebSocketAccept(headers.get(), |
| + handshake_challenge_response_, |
| + &failure_message_) && |
| + ValidateConnection(headers.get(), &failure_message_) && |
| + ValidateSubProtocol(headers.get(), |
| + requested_sub_protocols_, |
| + &sub_protocol_, |
| + &failure_message_) && |
| + ValidateExtensions(headers.get(), |
| + requested_extensions_, |
| + &extensions_, |
| + &failure_message_)) { |
| return OK; |
| } |
| + failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
| return ERR_INVALID_RESPONSE; |
| } |