| Index: net/websockets/websocket_basic_handshake_stream.cc
|
| diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc
|
| index 1a8154b4b8b8b0863de003985825b593e756559b..abeb2540f2fd419ddf809bbd18a0aa1d278bf094 100644
|
| --- a/net/websockets/websocket_basic_handshake_stream.cc
|
| +++ b/net/websockets/websocket_basic_handshake_stream.cc
|
| @@ -13,6 +13,7 @@
|
| #include "base/containers/hash_tables.h"
|
| #include "base/stl_util.h"
|
| #include "base/strings/string_util.h"
|
| +#include "base/strings/stringprintf.h"
|
| #include "crypto/random.h"
|
| #include "net/http/http_request_headers.h"
|
| #include "net/http/http_request_info.h"
|
| @@ -22,6 +23,7 @@
|
| #include "net/http/http_stream_parser.h"
|
| #include "net/socket/client_socket_handle.h"
|
| #include "net/websockets/websocket_basic_stream.h"
|
| +#include "net/websockets/websocket_extension_parser.h"
|
| #include "net/websockets/websocket_handshake_constants.h"
|
| #include "net/websockets/websocket_handshake_handler.h"
|
| #include "net/websockets/websocket_stream.h"
|
| @@ -29,6 +31,23 @@
|
| namespace net {
|
| namespace {
|
|
|
| +enum GetHeaderResult {
|
| + GET_HEADER_OK,
|
| + GET_HEADER_MISSING,
|
| + GET_HEADER_MULTIPLE,
|
| +};
|
| +
|
| +std::string MissingHeader(const std::string& header_name) {
|
| + return std::string("'") + header_name + "' header is missing";
|
| +}
|
| +
|
| +std::string MultipleHeaderValues(const std::string& header_name) {
|
| + return
|
| + std::string("'") +
|
| + header_name +
|
| + "' header must not appear more than once in a response";
|
| +}
|
| +
|
| std::string GenerateHandshakeChallenge() {
|
| std::string raw_challenge(websockets::kRawChallengeLength, '\0');
|
| crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
|
| @@ -46,57 +65,159 @@ void AddVectorHeaderIfNonEmpty(const char* name,
|
| headers->SetHeader(name, JoinString(value, ", "));
|
| }
|
|
|
| -// If |case_sensitive| is false, then |value| must be in lower-case.
|
| -bool ValidateSingleTokenHeader(
|
| - const scoped_refptr<HttpResponseHeaders>& headers,
|
| - const base::StringPiece& name,
|
| - const std::string& value,
|
| - bool case_sensitive) {
|
| +GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
|
| + const base::StringPiece& name,
|
| + std::string* value) {
|
| void* state = NULL;
|
| - std::string token;
|
| int tokens = 0;
|
| - bool has_value = false;
|
| + std::string token;
|
| while (headers->EnumerateHeader(&state, name, &token)) {
|
| if (++tokens > 1)
|
| - return false;
|
| - has_value = case_sensitive ? value == token
|
| - : LowerCaseEqualsASCII(token, value.c_str());
|
| + return GET_HEADER_MULTIPLE;
|
| + *value = token;
|
| + }
|
| + return tokens > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
|
| +}
|
| +
|
| +bool ValidateHeaderHasSingleValue(GetHeaderResult result,
|
| + const std::string& header_name,
|
| + std::string* failure_message) {
|
| + if (result == GET_HEADER_MISSING) {
|
| + *failure_message = MissingHeader(header_name);
|
| + return false;
|
| + }
|
| + if (result == GET_HEADER_MULTIPLE) {
|
| + *failure_message = MultipleHeaderValues(header_name);
|
| + return false;
|
| + }
|
| + DCHECK_EQ(result, GET_HEADER_OK);
|
| + return true;
|
| +}
|
| +
|
| +bool ValidateUpgrade(const HttpResponseHeaders* headers,
|
| + std::string* failure_message) {
|
| + std::string value;
|
| + GetHeaderResult result =
|
| + GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
|
| + if (!ValidateHeaderHasSingleValue(result,
|
| + websockets::kUpgrade,
|
| + failure_message)) {
|
| + return false;
|
| + }
|
| +
|
| + if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
|
| + *failure_message =
|
| + "'Upgrade' header value is not 'WebSocket': " + value;
|
| + return false;
|
| + }
|
| + return true;
|
| +}
|
| +
|
| +bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
|
| + const std::string& expected,
|
| + std::string* failure_message) {
|
| + std::string actual;
|
| + GetHeaderResult result =
|
| + GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
|
| + if (!ValidateHeaderHasSingleValue(result,
|
| + websockets::kSecWebSocketAccept,
|
| + failure_message)) {
|
| + return false;
|
| }
|
| - return has_value;
|
| +
|
| + if (expected != actual) {
|
| + *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
|
| + return false;
|
| + }
|
| + return true;
|
| +}
|
| +
|
| +bool ValidateConnection(const HttpResponseHeaders* headers,
|
| + std::string* failure_message) {
|
| + // Connection header is permitted to contain other tokens.
|
| + if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
|
| + *failure_message = MissingHeader(HttpRequestHeaders::kConnection);
|
| + return false;
|
| + }
|
| + if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
|
| + websockets::kUpgrade)) {
|
| + *failure_message = "'Connection' header value must contain 'Upgrade'";
|
| + return false;
|
| + }
|
| + return true;
|
| }
|
|
|
| bool ValidateSubProtocol(
|
| - const scoped_refptr<HttpResponseHeaders>& headers,
|
| + const HttpResponseHeaders* headers,
|
| const std::vector<std::string>& requested_sub_protocols,
|
| - std::string* sub_protocol) {
|
| + std::string* sub_protocol,
|
| + std::string* failure_message) {
|
| void* state = NULL;
|
| std::string token;
|
| + std::string last_token;
|
| base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
|
| requested_sub_protocols.end());
|
| - int accepted = 0;
|
| + int count = 0;
|
| + bool has_multiple_protocols = false;
|
| + bool has_invalid_protocol = false;
|
| +
|
| while (headers->EnumerateHeader(
|
| - &state, websockets::kSecWebSocketProtocol, &token)) {
|
| + &state, websockets::kSecWebSocketProtocol, &token) &&
|
| + !(has_multiple_protocols && has_invalid_protocol)) {
|
| if (requested_set.count(token) == 0)
|
| - return false;
|
| + has_invalid_protocol = true;
|
| + if (++count > 1)
|
| + has_multiple_protocols = true;
|
| + last_token = token;
|
| + }
|
|
|
| - *sub_protocol = token;
|
| - // The server is only allowed to accept one protocol.
|
| - if (++accepted > 1)
|
| - return false;
|
| + if (has_multiple_protocols) {
|
| + *failure_message = MultipleHeaderValues(websockets::kSecWebSocketProtocol);
|
| + return false;
|
| + } else if (count > 0 && requested_sub_protocols.size() == 0) {
|
| + *failure_message =
|
| + std::string("Response must not include 'Sec-WebSocket-Protocol' "
|
| + "header if not present in request: ")
|
| + + last_token;
|
| + return false;
|
| + } else if (has_invalid_protocol) {
|
| + *failure_message =
|
| + "'Sec-WebSocket-Protocol' header value '" +
|
| + last_token +
|
| + "' in response does not match any of sent values";
|
| + return false;
|
| + } else if (requested_sub_protocols.size() > 0 && count == 0) {
|
| + *failure_message =
|
| + "Sent non-empty 'Sec-WebSocket-Protocol' header "
|
| + "but no response was received";
|
| + return false;
|
| }
|
| - // If the browser requested > 0 protocols, the server is required to accept
|
| - // one.
|
| - return requested_set.empty() || accepted == 1;
|
| + *sub_protocol = last_token;
|
| + return true;
|
| }
|
|
|
| -bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers,
|
| +bool ValidateExtensions(const HttpResponseHeaders* headers,
|
| const std::vector<std::string>& requested_extensions,
|
| - std::string* extensions) {
|
| + std::string* extensions,
|
| + std::string* failure_message) {
|
| void* state = NULL;
|
| std::string token;
|
| while (headers->EnumerateHeader(
|
| &state, websockets::kSecWebSocketExtensions, &token)) {
|
| + WebSocketExtensionParser parser;
|
| + parser.Parse(token);
|
| + if (parser.has_error()) {
|
| + // TODO(yhirano) Set appropriate failure message.
|
| + *failure_message =
|
| + "'WebSocket-Extensions' header value is rejected by the parser: " +
|
| + token;
|
| + return false;
|
| + }
|
| // TODO(ricea): Accept permessage-deflate with valid parameters.
|
| + *failure_message =
|
| + "Found an unsupported extension '" +
|
| + parser.extension().name() +
|
| + "' in 'Sec-WebSocket-Extensions' header";
|
| return false;
|
| }
|
| return true;
|
| @@ -264,6 +385,10 @@ void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
|
| handshake_challenge_for_testing_.reset(new std::string(key));
|
| }
|
|
|
| +std::string WebSocketBasicHandshakeStream::FailureMessage() const {
|
| + return failure_message_;
|
| +}
|
| +
|
| void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
|
| const CompletionCallback& callback,
|
| int result) {
|
| @@ -289,26 +414,30 @@ int WebSocketBasicHandshakeStream::ValidateResponse() {
|
| // Other status codes are potentially risky (see the warnings in the
|
| // WHATWG WebSocket API spec) and so are dropped by default.
|
| default:
|
| + failure_message_ = base::StringPrintf("Unexpected status code: %d",
|
| + headers->response_code());
|
| return ERR_INVALID_RESPONSE;
|
| }
|
| }
|
|
|
| int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
|
| const scoped_refptr<HttpResponseHeaders>& headers) {
|
| - if (ValidateSingleTokenHeader(headers,
|
| - websockets::kUpgrade,
|
| - websockets::kWebSocketLowercase,
|
| - false) &&
|
| - ValidateSingleTokenHeader(headers,
|
| - websockets::kSecWebSocketAccept,
|
| - handshake_challenge_response_,
|
| - true) &&
|
| - headers->HasHeaderValue(HttpRequestHeaders::kConnection,
|
| - websockets::kUpgrade) &&
|
| - ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) &&
|
| - ValidateExtensions(headers, requested_extensions_, &extensions_)) {
|
| + if (ValidateUpgrade(headers.get(), &failure_message_) &&
|
| + ValidateSecWebSocketAccept(headers.get(),
|
| + handshake_challenge_response_,
|
| + &failure_message_) &&
|
| + ValidateConnection(headers.get(), &failure_message_) &&
|
| + ValidateSubProtocol(headers.get(),
|
| + requested_sub_protocols_,
|
| + &sub_protocol_,
|
| + &failure_message_) &&
|
| + ValidateExtensions(headers.get(),
|
| + requested_extensions_,
|
| + &extensions_,
|
| + &failure_message_)) {
|
| return OK;
|
| }
|
| + failure_message_ = "Error during WebSocket handshake: " + failure_message_;
|
| return ERR_INVALID_RESPONSE;
|
| }
|
|
|
|
|