Index: net/websockets/websocket_handshake.cc |
diff --git a/net/websockets/websocket_handshake.cc b/net/websockets/websocket_handshake.cc |
index c17ea344ccfc52f8d50d993c7c1031f3ca93abe3..6f660bc8b54b6154ce182c49d26aac61b7df4376 100644 |
--- a/net/websockets/websocket_handshake.cc |
+++ b/net/websockets/websocket_handshake.cc |
@@ -4,6 +4,11 @@ |
#include "net/websockets/websocket_handshake.h" |
+#include <algorithm> |
+#include <vector> |
+ |
+#include "base/md5.h" |
+#include "base/rand_util.h" |
#include "base/ref_counted.h" |
#include "base/string_util.h" |
#include "net/http/http_response_headers.h" |
@@ -14,19 +19,6 @@ namespace net { |
const int WebSocketHandshake::kWebSocketPort = 80; |
const int WebSocketHandshake::kSecureWebSocketPort = 443; |
-const char WebSocketHandshake::kServerHandshakeHeader[] = |
- "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; |
-const size_t WebSocketHandshake::kServerHandshakeHeaderLength = |
- sizeof(kServerHandshakeHeader) - 1; |
- |
-const char WebSocketHandshake::kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; |
-const size_t WebSocketHandshake::kUpgradeHeaderLength = |
- sizeof(kUpgradeHeader) - 1; |
- |
-const char WebSocketHandshake::kConnectionHeader[] = "Connection: Upgrade\r\n"; |
-const size_t WebSocketHandshake::kConnectionHeaderLength = |
- sizeof(kConnectionHeader) - 1; |
- |
WebSocketHandshake::WebSocketHandshake( |
const GURL& url, |
const std::string& origin, |
@@ -46,19 +38,94 @@ bool WebSocketHandshake::is_secure() const { |
return url_.SchemeIs("wss"); |
} |
-std::string WebSocketHandshake::CreateClientHandshakeMessage() const { |
+std::string WebSocketHandshake::CreateClientHandshakeMessage() { |
+ if (!parameter_.get()) { |
+ parameter_.reset(new Parameter); |
+ parameter_->GenerateKeys(); |
+ } |
std::string msg; |
+ |
+ // WebSocket protocol 4.1 Opening handshake. |
+ |
msg = "GET "; |
- msg += url_.path(); |
+ msg += GetResourceName(); |
+ msg += " HTTP/1.1\r\n"; |
+ |
+ std::vector<std::string> fields; |
+ |
+ fields.push_back("Upgrade: WebSocket"); |
+ fields.push_back("Connection: Upgrade"); |
+ |
+ fields.push_back("Host: " + GetHostFieldValue()); |
+ |
+ fields.push_back("Origin: " + GetOriginFieldValue()); |
+ |
+ if (!protocol_.empty()) |
+ fields.push_back("Sec-WebSocket-Protocol: " + protocol_); |
+ |
+ // TODO(ukai): Add cookie if necessary. |
+ |
+ fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); |
+ fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); |
+ |
+ std::random_shuffle(fields.begin(), fields.end()); |
+ |
+ for (size_t i = 0; i < fields.size(); i++) { |
+ msg += fields[i] + "\r\n"; |
+ } |
+ msg += "\r\n"; |
+ |
+ msg.append(parameter_->GetKey3()); |
+ return msg; |
+} |
+ |
+int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { |
+ mode_ = MODE_INCOMPLETE; |
+ int eoh = HttpUtil::LocateEndOfHeaders(data, len); |
+ if (eoh < 0) |
+ return -1; |
+ |
+ scoped_refptr<HttpResponseHeaders> headers( |
+ new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); |
+ |
+ if (headers->response_code() != 101) { |
+ mode_ = MODE_FAILED; |
+ DLOG(INFO) << "Bad response code: " << headers->response_code(); |
+ return eoh; |
+ } |
+ mode_ = MODE_NORMAL; |
+ if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { |
+ DLOG(INFO) << "Process Headers failed: " |
+ << std::string(data, eoh); |
+ mode_ = MODE_FAILED; |
+ return eoh; |
+ } |
+ if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) { |
+ mode_ = MODE_INCOMPLETE; |
+ return -1; |
+ } |
+ uint8 expected[Parameter::kExpectedResponseSize]; |
+ parameter_->GetExpectedResponse(expected); |
+ if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { |
+ mode_ = MODE_FAILED; |
+ return eoh + Parameter::kExpectedResponseSize; |
+ } |
+ mode_ = MODE_CONNECTED; |
+ return eoh + Parameter::kExpectedResponseSize; |
+} |
+ |
+std::string WebSocketHandshake::GetResourceName() const { |
+ std::string resource_name = url_.path(); |
if (url_.has_query()) { |
- msg += "?"; |
- msg += url_.query(); |
+ resource_name += "?"; |
+ resource_name += url_.query(); |
} |
- msg += " HTTP/1.1\r\n"; |
- msg += kUpgradeHeader; |
- msg += kConnectionHeader; |
- msg += "Host: "; |
- msg += StringToLowerASCII(url_.host()); |
+ return resource_name; |
+} |
+ |
+std::string WebSocketHandshake::GetHostFieldValue() const { |
+ // url_.host() is expected to be encoded in punnycode here. |
+ std::string host = StringToLowerASCII(url_.host()); |
if (url_.has_port()) { |
bool secure = is_secure(); |
int port = url_.EffectiveIntPort(); |
@@ -66,12 +133,14 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const { |
port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || |
(secure && |
port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { |
- msg += ":"; |
- msg += IntToString(port); |
+ host += ":"; |
+ host += IntToString(port); |
} |
} |
- msg += "\r\n"; |
- msg += "Origin: "; |
+ return host; |
+} |
+ |
+std::string WebSocketHandshake::GetOriginFieldValue() const { |
// It's OK to lowercase the origin as the Origin header does not contain |
// the path or query portions, as per |
// http://tools.ietf.org/html/draft-abarth-origin-00. |
@@ -79,91 +148,13 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const { |
// TODO(satorux): Should we trim the port portion here if it's 80 for |
// http:// or 443 for https:// ? Or can we assume it's done by the |
// client of the library? |
- msg += StringToLowerASCII(origin_); |
- msg += "\r\n"; |
- if (!protocol_.empty()) { |
- msg += "WebSocket-Protocol: "; |
- msg += protocol_; |
- msg += "\r\n"; |
- } |
- // TODO(ukai): Add cookie if necessary. |
- msg += "\r\n"; |
- return msg; |
+ return StringToLowerASCII(origin_); |
} |
-int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { |
- mode_ = MODE_INCOMPLETE; |
- if (len < kServerHandshakeHeaderLength) { |
- return -1; |
- } |
- if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { |
- mode_ = MODE_NORMAL; |
- } else { |
- int eoh = HttpUtil::LocateEndOfHeaders(data, len); |
- if (eoh < 0) |
- return -1; |
- return eoh; |
- } |
- const char* p = data + kServerHandshakeHeaderLength; |
- const char* end = data + len + 1; |
- |
- if (mode_ == MODE_NORMAL) { |
- size_t header_size = end - p; |
- if (header_size < kUpgradeHeaderLength) |
- return -1; |
- if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { |
- mode_ = MODE_FAILED; |
- DLOG(INFO) << "Bad Upgrade Header " |
- << std::string(p, kUpgradeHeaderLength); |
- return p - data; |
- } |
- p += kUpgradeHeaderLength; |
- header_size = end - p; |
- if (header_size < kConnectionHeaderLength) |
- return -1; |
- if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { |
- mode_ = MODE_FAILED; |
- DLOG(INFO) << "Bad Connection Header " |
- << std::string(p, kConnectionHeaderLength); |
- return p - data; |
- } |
- p += kConnectionHeaderLength; |
- } |
- |
- int eoh = HttpUtil::LocateEndOfHeaders(data, len); |
- if (eoh == -1) |
- return eoh; |
- |
- scoped_refptr<HttpResponseHeaders> headers( |
- new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); |
- if (!ProcessHeaders(*headers)) { |
- DLOG(INFO) << "Process Headers failed: " |
- << std::string(data, eoh); |
- mode_ = MODE_FAILED; |
- } |
- switch (mode_) { |
- case MODE_NORMAL: |
- if (CheckResponseHeaders()) { |
- mode_ = MODE_CONNECTED; |
- } else { |
- mode_ = MODE_FAILED; |
- } |
- break; |
- default: |
- mode_ = MODE_FAILED; |
- break; |
- } |
- return eoh; |
-} |
- |
-// Gets the value of the specified header. |
-// It assures only one header of |name| in |headers|. |
-// Returns true iff single header of |name| is found in |headers| |
-// and |value| is filled with the value. |
-// Returns false otherwise. |
-static bool GetSingleHeader(const HttpResponseHeaders& headers, |
- const std::string& name, |
- std::string* value) { |
+/* static */ |
+bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, |
+ const std::string& name, |
+ std::string* value) { |
std::string first_value; |
void* iter = NULL; |
if (!headers.EnumerateHeader(&iter, name, &first_value)) |
@@ -179,16 +170,25 @@ static bool GetSingleHeader(const HttpResponseHeaders& headers, |
} |
bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { |
- if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) |
+ std::string value; |
+ if (!GetSingleHeader(headers, "upgrade", &value) || |
+ value != "WebSocket") |
+ return false; |
+ |
+ if (!GetSingleHeader(headers, "connection", &value) || |
+ !LowerCaseEqualsASCII(value, "upgrade")) |
+ return false; |
+ |
+ if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) |
return false; |
- if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) |
+ if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) |
return false; |
// If |protocol_| is not specified by client, we don't care if there's |
// protocol field or not as specified in the spec. |
if (!protocol_.empty() |
- && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) |
+ && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) |
return false; |
return true; |
} |
@@ -204,6 +204,100 @@ bool WebSocketHandshake::CheckResponseHeaders() const { |
return true; |
} |
+namespace { |
+ |
+// unsigned int version of base::RandInt(). |
+// we can't use base::RandInt(), because max would be negative if it is |
+// represented as int, so DCHECK(min <= max) fails. |
+uint32 RandUint32(uint32 min, uint32 max) { |
+ DCHECK(min <= max); |
+ |
+ uint64 range = static_cast<int64>(max) - min + 1; |
+ uint64 number = base::RandUint64(); |
+ // TODO(ukai): fix to be uniform. |
+ // the distribution of the result of modulo will be biased. |
+ uint32 result = min + static_cast<uint32>(number % range); |
+ DCHECK(result >= min && result <= max); |
+ return result; |
+} |
+ |
+} |
+ |
+uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = |
+ RandUint32; |
+uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; |
+WebSocketHandshake::Parameter::Parameter() |
+ : number_1_(0), number_2_(0) { |
+ if (randomCharacterInSecWebSocketKey[0] == '\0') { |
+ int i = 0; |
+ for (int ch = 0x21; ch <= 0x2F; ch++, i++) |
+ randomCharacterInSecWebSocketKey[i] = ch; |
+ for (int ch = 0x3A; ch <= 0x7E; ch++, i++) |
+ randomCharacterInSecWebSocketKey[i] = ch; |
+ } |
+} |
+ |
+WebSocketHandshake::Parameter::~Parameter() {} |
+ |
+void WebSocketHandshake::Parameter::GenerateKeys() { |
+ GenerateSecWebSocketKey(&number_1_, &key_1_); |
+ GenerateSecWebSocketKey(&number_2_, &key_2_); |
+ GenerateKey3(); |
+} |
+ |
+static void SetChallengeNumber(uint8* buf, uint32 number) { |
+ uint8* p = buf + 3; |
+ for (int i = 0; i < 4; i++) { |
+ *p = (uint8)(number & 0xFF); |
+ --p; |
+ number >>= 8; |
+ } |
+} |
+ |
+void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { |
+ uint8 challenge[kExpectedResponseSize]; |
+ SetChallengeNumber(&challenge[0], number_1_); |
+ SetChallengeNumber(&challenge[4], number_2_); |
+ memcpy(&challenge[8], key_3_.data(), kKey3Size); |
+ MD5Digest digest; |
+ MD5Sum(challenge, kExpectedResponseSize, &digest); |
+ memcpy(expected, digest.a, kExpectedResponseSize); |
+} |
+ |
+/* static */ |
+void WebSocketHandshake::Parameter::SetRandomNumberGenerator( |
+ uint32 (*rand)(uint32 min, uint32 max)) { |
+ rand_ = rand; |
+} |
+ |
+void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( |
+ uint32* number, std::string* key) { |
+ uint32 space = rand_(1, 12); |
+ uint32 max = 4294967295U / space; |
+ *number = rand_(0, max); |
+ uint32 product = *number * space; |
+ |
+ std::string s = StringPrintf("%010u", product); |
+ for (uint32 i = 0; i < space; i++) { |
+ int pos = rand_(1, s.length() - 1); |
+ s = s.substr(0, pos) + " " + s.substr(pos); |
+ } |
+ int n = rand_(1, 12); |
+ for (int i = 0; i < n; i++) { |
+ int pos = rand_(0, s.length()); |
+ int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); |
+ s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + |
+ s.substr(pos); |
+ } |
+ *key = s; |
+} |
+ |
+void WebSocketHandshake::Parameter::GenerateKey3() { |
+ key_3_.clear(); |
+ for (int i = 0; i < 8; i++) { |
+ key_3_.append(1, rand_(0, 255)); |
+ } |
+} |
} // namespace net |