Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(69)

Unified Diff: net/websockets/websocket_basic_handshake_stream.cc

Issue 105833003: Fail WebSocket channel when handshake fails. (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Created 6 years, 11 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
« no previous file with comments | « net/websockets/websocket_basic_handshake_stream.h ('k') | net/websockets/websocket_channel.h » ('j') | no next file with comments »
Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
Index: net/websockets/websocket_basic_handshake_stream.cc
diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc
index 73a10453be0520d33aa04a56db34246abf654c94..e30cf9ed2642d320a4166e76bf12d299d03a1caf 100644
--- a/net/websockets/websocket_basic_handshake_stream.cc
+++ b/net/websockets/websocket_basic_handshake_stream.cc
@@ -6,6 +6,7 @@
#include <algorithm>
#include <iterator>
+#include <string>
#include "base/base64.h"
#include "base/basictypes.h"
@@ -13,6 +14,7 @@
#include "base/containers/hash_tables.h"
#include "base/stl_util.h"
#include "base/strings/string_util.h"
+#include "base/strings/stringprintf.h"
#include "crypto/random.h"
#include "net/http/http_request_headers.h"
#include "net/http/http_request_info.h"
@@ -22,6 +24,7 @@
#include "net/http/http_stream_parser.h"
#include "net/socket/client_socket_handle.h"
#include "net/websockets/websocket_basic_stream.h"
+#include "net/websockets/websocket_extension_parser.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "net/websockets/websocket_handshake_handler.h"
#include "net/websockets/websocket_stream.h"
@@ -29,6 +32,23 @@
namespace net {
namespace {
+enum GetHeaderResult {
+ GET_HEADER_OK,
+ GET_HEADER_MISSING,
+ GET_HEADER_MULTIPLE,
+};
+
+std::string MissingHeaderMessage(const std::string& header_name) {
+ return std::string("'") + header_name + "' header is missing";
+}
+
+std::string MultipleHeaderValuesMessage(const std::string& header_name) {
+ return
+ std::string("'") +
+ header_name +
+ "' header must not appear more than once in a response";
+}
+
std::string GenerateHandshakeChallenge() {
std::string raw_challenge(websockets::kRawChallengeLength, '\0');
crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length());
@@ -45,57 +65,162 @@ void AddVectorHeaderIfNonEmpty(const char* name,
headers->SetHeader(name, JoinString(value, ", "));
}
-// If |case_sensitive| is false, then |value| must be in lower-case.
-bool ValidateSingleTokenHeader(
- const scoped_refptr<HttpResponseHeaders>& headers,
- const base::StringPiece& name,
- const std::string& value,
- bool case_sensitive) {
+GetHeaderResult GetSingleHeaderValue(const HttpResponseHeaders* headers,
+ const base::StringPiece& name,
+ std::string* value) {
void* state = NULL;
- std::string token;
- int tokens = 0;
- bool has_value = false;
- while (headers->EnumerateHeader(&state, name, &token)) {
- if (++tokens > 1)
- return false;
- has_value = case_sensitive ? value == token
- : LowerCaseEqualsASCII(token, value.c_str());
+ size_t num_values = 0;
+ std::string temp_value;
+ while (headers->EnumerateHeader(&state, name, &temp_value)) {
+ if (++num_values > 1)
+ return GET_HEADER_MULTIPLE;
+ *value = temp_value;
+ }
+ return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING;
+}
+
+bool ValidateHeaderHasSingleValue(GetHeaderResult result,
+ const std::string& header_name,
+ std::string* failure_message) {
+ if (result == GET_HEADER_MISSING) {
+ *failure_message = MissingHeaderMessage(header_name);
+ return false;
+ }
+ if (result == GET_HEADER_MULTIPLE) {
+ *failure_message = MultipleHeaderValuesMessage(header_name);
+ return false;
+ }
+ DCHECK_EQ(result, GET_HEADER_OK);
+ return true;
+}
+
+bool ValidateUpgrade(const HttpResponseHeaders* headers,
+ std::string* failure_message) {
+ std::string value;
+ GetHeaderResult result =
+ GetSingleHeaderValue(headers, websockets::kUpgrade, &value);
+ if (!ValidateHeaderHasSingleValue(result,
+ websockets::kUpgrade,
+ failure_message)) {
+ return false;
+ }
+
+ if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) {
+ *failure_message =
+ "'Upgrade' header value is not 'WebSocket': " + value;
+ return false;
+ }
+ return true;
+}
+
+bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers,
+ const std::string& expected,
+ std::string* failure_message) {
+ std::string actual;
+ GetHeaderResult result =
+ GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual);
+ if (!ValidateHeaderHasSingleValue(result,
+ websockets::kSecWebSocketAccept,
+ failure_message)) {
+ return false;
+ }
+
+ if (expected != actual) {
+ *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value";
+ return false;
+ }
+ return true;
+}
+
+bool ValidateConnection(const HttpResponseHeaders* headers,
+ std::string* failure_message) {
+ // Connection header is permitted to contain other tokens.
+ if (!headers->HasHeader(HttpRequestHeaders::kConnection)) {
+ *failure_message = MissingHeaderMessage(HttpRequestHeaders::kConnection);
+ return false;
}
- return has_value;
+ if (!headers->HasHeaderValue(HttpRequestHeaders::kConnection,
+ websockets::kUpgrade)) {
+ *failure_message = "'Connection' header value must contain 'Upgrade'";
+ return false;
+ }
+ return true;
}
bool ValidateSubProtocol(
- const scoped_refptr<HttpResponseHeaders>& headers,
+ const HttpResponseHeaders* headers,
const std::vector<std::string>& requested_sub_protocols,
- std::string* sub_protocol) {
+ std::string* sub_protocol,
+ std::string* failure_message) {
void* state = NULL;
- std::string token;
+ std::string value;
base::hash_set<std::string> requested_set(requested_sub_protocols.begin(),
requested_sub_protocols.end());
- int accepted = 0;
- while (headers->EnumerateHeader(
- &state, websockets::kSecWebSocketProtocol, &token)) {
- if (requested_set.count(token) == 0)
- return false;
+ int count = 0;
+ bool has_multiple_protocols = false;
+ bool has_invalid_protocol = false;
+
+ while (!has_invalid_protocol || !has_multiple_protocols) {
+ std::string temp_value;
+ if (!headers->EnumerateHeader(
+ &state, websockets::kSecWebSocketProtocol, &temp_value))
+ break;
+ value = temp_value;
+ if (requested_set.count(value) == 0)
+ has_invalid_protocol = true;
+ if (++count > 1)
+ has_multiple_protocols = true;
+ }
- *sub_protocol = token;
- // The server is only allowed to accept one protocol.
- if (++accepted > 1)
- return false;
+ if (has_multiple_protocols) {
+ *failure_message =
+ MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol);
+ return false;
+ } else if (count > 0 && requested_sub_protocols.size() == 0) {
+ *failure_message =
+ std::string("Response must not include 'Sec-WebSocket-Protocol' "
+ "header if not present in request: ")
+ + value;
+ return false;
+ } else if (has_invalid_protocol) {
+ *failure_message =
+ "'Sec-WebSocket-Protocol' header value '" +
+ value +
+ "' in response does not match any of sent values";
+ return false;
+ } else if (requested_sub_protocols.size() > 0 && count == 0) {
+ *failure_message =
+ "Sent non-empty 'Sec-WebSocket-Protocol' header "
+ "but no response was received";
+ return false;
}
- // If the browser requested > 0 protocols, the server is required to accept
- // one.
- return requested_set.empty() || accepted == 1;
+ *sub_protocol = value;
+ return true;
}
-bool ValidateExtensions(const scoped_refptr<HttpResponseHeaders>& headers,
+bool ValidateExtensions(const HttpResponseHeaders* headers,
const std::vector<std::string>& requested_extensions,
- std::string* extensions) {
+ std::string* extensions,
+ std::string* failure_message) {
void* state = NULL;
- std::string token;
+ std::string value;
while (headers->EnumerateHeader(
- &state, websockets::kSecWebSocketExtensions, &token)) {
+ &state, websockets::kSecWebSocketExtensions, &value)) {
+ WebSocketExtensionParser parser;
+ parser.Parse(value);
+ if (parser.has_error()) {
+ // TODO(yhirano) Set appropriate failure message.
+ *failure_message =
+ "'Sec-WebSocket-Extensions' header value is "
+ "rejected by the parser: " +
+ value;
+ return false;
+ }
// TODO(ricea): Accept permessage-deflate with valid parameters.
+ *failure_message =
+ "Found an unsupported extension '" +
+ parser.extension().name() +
+ "' in 'Sec-WebSocket-Extensions' header";
return false;
}
return true;
@@ -267,6 +392,10 @@ void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
handshake_challenge_for_testing_.reset(new std::string(key));
}
+std::string WebSocketBasicHandshakeStream::GetFailureMessage() const {
+ return failure_message_;
+}
+
void WebSocketBasicHandshakeStream::ReadResponseHeadersCallback(
const CompletionCallback& callback,
int result) {
@@ -292,26 +421,30 @@ int WebSocketBasicHandshakeStream::ValidateResponse() {
// Other status codes are potentially risky (see the warnings in the
// WHATWG WebSocket API spec) and so are dropped by default.
default:
+ failure_message_ = base::StringPrintf("Unexpected status code: %d",
+ headers->response_code());
return ERR_INVALID_RESPONSE;
}
}
int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
const scoped_refptr<HttpResponseHeaders>& headers) {
- if (ValidateSingleTokenHeader(headers,
- websockets::kUpgrade,
- websockets::kWebSocketLowercase,
- false) &&
- ValidateSingleTokenHeader(headers,
- websockets::kSecWebSocketAccept,
- handshake_challenge_response_,
- true) &&
- headers->HasHeaderValue(HttpRequestHeaders::kConnection,
- websockets::kUpgrade) &&
- ValidateSubProtocol(headers, requested_sub_protocols_, &sub_protocol_) &&
- ValidateExtensions(headers, requested_extensions_, &extensions_)) {
+ if (ValidateUpgrade(headers.get(), &failure_message_) &&
+ ValidateSecWebSocketAccept(headers.get(),
+ handshake_challenge_response_,
+ &failure_message_) &&
+ ValidateConnection(headers.get(), &failure_message_) &&
+ ValidateSubProtocol(headers.get(),
+ requested_sub_protocols_,
+ &sub_protocol_,
+ &failure_message_) &&
+ ValidateExtensions(headers.get(),
+ requested_extensions_,
+ &extensions_,
+ &failure_message_)) {
return OK;
}
+ failure_message_ = "Error during WebSocket handshake: " + failure_message_;
return ERR_INVALID_RESPONSE;
}
« no previous file with comments | « net/websockets/websocket_basic_handshake_stream.h ('k') | net/websockets/websocket_channel.h » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698