Index: net/websockets/websocket_basic_handshake_stream.cc |
diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc |
index 73a10453be0520d33aa04a56db34246abf654c94..e30cf9ed2642d320a4166e76bf12d299d03a1caf 100644 |
--- a/net/websockets/websocket_basic_handshake_stream.cc |
+++ b/net/websockets/websocket_basic_handshake_stream.cc |
@@ -6,6 +6,7 @@ |
#include <algorithm> |
#include <iterator> |
+#include <string> |
#include "base/base64.h" |
#include "base/basictypes.h" |
@@ -13,6 +14,7 @@ |
#include "base/containers/hash_tables.h" |
#include "base/stl_util.h" |
#include "base/strings/string_util.h" |
+#include "base/strings/stringprintf.h" |
#include "crypto/random.h" |
#include "net/http/http_request_headers.h" |
#include "net/http/http_request_info.h" |
@@ -22,6 +24,7 @@ |
#include "net/http/http_stream_parser.h" |
#include "net/socket/client_socket_handle.h" |
#include "net/websockets/websocket_basic_stream.h" |
+#include "net/websockets/websocket_extension_parser.h" |
#include "net/websockets/websocket_handshake_constants.h" |
#include "net/websockets/websocket_handshake_handler.h" |
#include "net/websockets/websocket_stream.h" |
@@ -29,6 +32,23 @@ |
namespace net { |
namespace { |
+enum GetHeaderResult { |
+ GET_HEADER_OK, |
+ GET_HEADER_MISSING, |
+ GET_HEADER_MULTIPLE, |
+}; |
+ |
+std::string MissingHeaderMessage(const std::string& header_name) { |
+ return std::string("'") + header_name + "' header is missing"; |
+} |
+ |
+std::string MultipleHeaderValuesMessage(const std::string& header_name) { |
+ return |
+ std::string("'") + |
+ header_name + |
+ "' header must not appear more than once in a response"; |
+} |
+ |
std::string GenerateHandshakeChallenge() { |
std::string raw_challenge(websockets::kRawChallengeLength, '\0'); |
crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); |
@@ -45,57 +65,162 @@ void AddVectorHeaderIfNonEmpty(const char* name, |
headers->SetHeader(name, JoinString(value, ", ")); |
} |
-// If |case_sensitive| is false, then |value| must be in lower-case. |
-bool ValidateSingleTokenHeader( |
- const scoped_refptr<HttpResponseHeaders>& headers, |
- const base::StringPiece& name, |
- const std::string& value, |
- bool case_sensitive) { |
+GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers, |
+ const base::StringPiece& name, |
+ std::string* value) { |
void* state = NULL; |
- std::string token; |
- int tokens = 0; |
- bool has_value = false; |
- while (headers->EnumerateHeader(&state, name, &token)) { |
- if (++tokens > 1) |
- return false; |
- has_value = case_sensitive ? value == token |
- : LowerCaseEqualsASCII(token, value.c_str()); |
+ size_t num_values = 0; |
+ std::string temp_value; |
+ while (headers->EnumerateHeader(&state, name, &temp_value)) { |
+ if (++num_values > 1) |
+ return GET_HEADER_MULTIPLE; |
+ *value = temp_value; |
+ } |
+ return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING; |
+} |
+ |
+bool ValidateHeaderHasSingleValue(GetHeaderResult result, |
+ const std::string& header_name, |
+ std::string* failure_message) { |
+ if (result == GET_HEADER_MISSING) { |
+ *failure_message = MissingHeaderMessage(header_name); |
+ return false; |
+ } |
+ if (result == GET_HEADER_MULTIPLE) { |
+ *failure_message = MultipleHeaderValuesMessage(header_name); |
+ return false; |
+ } |
+ DCHECK_EQ(result, GET_HEADER_OK); |
+ return true; |
+} |
+ |
+bool ValidateUpgrade(const HttpResponseHeaders* headers, |
+ std::string* failure_message) { |
+ std::string value; |
+ GetHeaderResult result = |
+ GetSingleHeaderValue(headers, websockets::kUpgrade, &value); |
+ if (!ValidateHeaderHasSingleValue(result, |
+ websockets::kUpgrade, |
+ failure_message)) { |
+ return false; |
+ } |
+ |
+ if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { |
+ *failure_message = |
+ "'Upgrade' header value is not 'WebSocket': " + value; |
+ return false; |
+ } |
+ return true; |
+} |
+ |
+bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, |
+ const std::string& expected, |
+ std::string* failure_message) { |
+ std::string actual; |
+ GetHeaderResult result = |
+ GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); |
+ if (!ValidateHeaderHasSingleValue(result, |
+ websockets::kSecWebSocketAccept, |
+ failure_message)) { |
+ return false; |
+ } |
+ |
+ if (expected != actual) { |
+ *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; |
+ return false; |
+ } |
+ return true; |
+} |
+ |
+bool ValidateConnection(const HttpResponseHeaders* headers, |
+ std::string* failure_message) { |
+ // Connection header is permitted to contain other tokens. |
+ if (!headers->HasHeader(HttpRequestHeaders::kConnection)) { |
+ *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection); |
+ return false; |
} |
- return has_value; |
+ if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection, |
+ websockets::kUpgrade)) { |
+ *failure_message = "'Connection' header value must contain 'Upgrade'"; |
+ return false; |
+ } |
+ return true; |
} |
bool ValidateSubProtocol( |
- const scoped_refptr<HttpResponseHeaders>& headers, |
+ const HttpResponseHeaders* headers, |
const std::vector<std::string>& requested_sub_protocols, |
- std::string* sub_protocol) { |
+ std::string* sub_protocol, |
+ std::string* failure_message) { |
void* state = NULL; |
- std::string token; |
+ std::string value; |
base::hash_set<std::string> requested_set(requested_sub_protocols.begin(), |
requested_sub_protocols.end()); |
- int accepted = 0; |
- while (headers->EnumerateHeader( |
- &state, websockets::kSecWebSocketProtocol, &token)) { |
- if (requested_set.count(token) == 0) |
- return false; |
+ int count = 0; |
+ bool has_multiple_protocols = false; |
+ bool has_invalid_protocol = false; |
+ |
+ while (!has_invalid_protocol || !has_multiple_protocols) { |
+ std::string temp_value; |
+ if (!headers->EnumerateHeader( |
+ &state, websockets::kSecWebSocketProtocol, &temp_value)) |
+ break; |
+ value = temp_value; |
+ if (requested_set.count(value) == 0) |
+ has_invalid_protocol = true; |
+ if (++count > 1) |
+ has_multiple_protocols = true; |
+ } |
- *sub_protocol = token; |
- // The server is only allowed to accept one protocol. |
- if (++accepted > 1) |
- return false; |
+ if (has_multiple_protocols) { |
+ *failure_message = |
+ MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); |
+ return false; |
+ } else if (count > 0 && requested_sub_protocols.size() == 0) { |
+ *failure_message = |
+ std::string("Response must not include 'Sec-WebSocket-Protocol' " |
+ "header if not present in request: ") |
+ + value; |
+ return false; |
+ } else if (has_invalid_protocol) { |
+ *failure_message = |
+ "'Sec-WebSocket-Protocol' header value '" + |
+ value + |
+ "' in response does not match any of sent values"; |
+ return false; |
+ } else if (requested_sub_protocols.size() > 0 && count == 0) { |
+ *failure_message = |
+ "Sent non-empty 'Sec-WebSocket-Protocol' header " |
+ "but no response was received"; |
+ return false; |
} |
- // If the browser requested > 0 protocols, the server is required to accept |
- // one. |
- return requested_set.empty() || accepted == 1; |
+ *sub_protocol = value; |
+ return true; |
} |
-bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers, |
+bool ValidateExtensions(const HttpResponseHeaders* headers, |
const std::vector<std::string>& requested_extensions, |
- std::string* extensions) { |
+ std::string* extensions, |
+ std::string* failure_message) { |
void* state = NULL; |
- std::string token; |
+ std::string value; |
while (headers->EnumerateHeader( |
- &state, websockets::kSecWebSocketExtensions, &token)) { |
+ &state, websockets::kSecWebSocketExtensions, &value)) { |
+ WebSocketExtensionParser parser; |
+ parser.Parse(value); |
+ if (parser.has_error()) { |
+ // TODO(yhirano) Set appropriate failure message. |
+ *failure_message = |
+ "'Sec-WebSocket-Extensions' header value is " |
+ "rejected by the parser: " + |
+ value; |
+ return false; |
+ } |
// TODO(ricea): Accept permessage-deflate with valid parameters. |
+ *failure_message = |
+ "Found an unsupported extension '" + |
+ parser.extension().name() + |
+ "' in 'Sec-WebSocket-Extensions' header"; |
return false; |
} |
return true; |
@@ -267,6 +392,10 @@ void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
handshake_challenge_for_testing_.reset(new std::string(key)); |
} |
+std::string WebSocketBasicHandshakeStream::GetFailureMessage() const { |
+ return failure_message_; |
+} |
+ |
void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback( |
const CompletionCallback& callback, |
int result) { |
@@ -292,26 +421,30 @@ int WebSocketBasicHandshakeStream::ValidateResponse() { |
// Other status codes are potentially risky (see the warnings in the |
// WHATWG WebSocket API spec) and so are dropped by default. |
default: |
+ failure_message_ = base::StringPrintf("Unexpected status code: %d", |
+ headers->response_code()); |
return ERR_INVALID_RESPONSE; |
} |
} |
int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
const scoped_refptr<HttpResponseHeaders>& headers) { |
- if (ValidateSingleTokenHeader(headers, |
- websockets::kUpgrade, |
- websockets::kWebSocketLowercase, |
- false) && |
- ValidateSingleTokenHeader(headers, |
- websockets::kSecWebSocketAccept, |
- handshake_challenge_response_, |
- true) && |
- headers->HasHeaderValue(HttpRequestHeaders::kConnection, |
- websockets::kUpgrade) && |
- ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) && |
- ValidateExtensions(headers, requested_extensions_, &extensions_)) { |
+ if (ValidateUpgrade(headers.get(), &failure_message_) && |
+ ValidateSecWebSocketAccept(headers.get(), |
+ handshake_challenge_response_, |
+ &failure_message_) && |
+ ValidateConnection(headers.get(), &failure_message_) && |
+ ValidateSubProtocol(headers.get(), |
+ requested_sub_protocols_, |
+ &sub_protocol_, |
+ &failure_message_) && |
+ ValidateExtensions(headers.get(), |
+ requested_extensions_, |
+ &extensions_, |
+ &failure_message_)) { |
return OK; |
} |
+ failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
return ERR_INVALID_RESPONSE; |
} |