| Index: net/websockets/websocket_handshake.cc
|
| diff --git a/net/websockets/websocket_handshake.cc b/net/websockets/websocket_handshake.cc
|
| index c17ea344ccfc52f8d50d993c7c1031f3ca93abe3..6f660bc8b54b6154ce182c49d26aac61b7df4376 100644
|
| --- a/net/websockets/websocket_handshake.cc
|
| +++ b/net/websockets/websocket_handshake.cc
|
| @@ -4,6 +4,11 @@
|
|
|
| #include "net/websockets/websocket_handshake.h"
|
|
|
| +#include <algorithm>
|
| +#include <vector>
|
| +
|
| +#include "base/md5.h"
|
| +#include "base/rand_util.h"
|
| #include "base/ref_counted.h"
|
| #include "base/string_util.h"
|
| #include "net/http/http_response_headers.h"
|
| @@ -14,19 +19,6 @@ namespace net {
|
| const int WebSocketHandshake::kWebSocketPort = 80;
|
| const int WebSocketHandshake::kSecureWebSocketPort = 443;
|
|
|
| -const char WebSocketHandshake::kServerHandshakeHeader[] =
|
| - "HTTP/1.1 101 Web Socket Protocol Handshake\r\n";
|
| -const size_t WebSocketHandshake::kServerHandshakeHeaderLength =
|
| - sizeof(kServerHandshakeHeader) - 1;
|
| -
|
| -const char WebSocketHandshake::kUpgradeHeader[] = "Upgrade: WebSocket\r\n";
|
| -const size_t WebSocketHandshake::kUpgradeHeaderLength =
|
| - sizeof(kUpgradeHeader) - 1;
|
| -
|
| -const char WebSocketHandshake::kConnectionHeader[] = "Connection: Upgrade\r\n";
|
| -const size_t WebSocketHandshake::kConnectionHeaderLength =
|
| - sizeof(kConnectionHeader) - 1;
|
| -
|
| WebSocketHandshake::WebSocketHandshake(
|
| const GURL& url,
|
| const std::string& origin,
|
| @@ -46,19 +38,94 @@ bool WebSocketHandshake::is_secure() const {
|
| return url_.SchemeIs("wss");
|
| }
|
|
|
| -std::string WebSocketHandshake::CreateClientHandshakeMessage() const {
|
| +std::string WebSocketHandshake::CreateClientHandshakeMessage() {
|
| + if (!parameter_.get()) {
|
| + parameter_.reset(new Parameter);
|
| + parameter_->GenerateKeys();
|
| + }
|
| std::string msg;
|
| +
|
| + // WebSocket protocol 4.1 Opening handshake.
|
| +
|
| msg = "GET ";
|
| - msg += url_.path();
|
| + msg += GetResourceName();
|
| + msg += " HTTP/1.1\r\n";
|
| +
|
| + std::vector<std::string> fields;
|
| +
|
| + fields.push_back("Upgrade: WebSocket");
|
| + fields.push_back("Connection: Upgrade");
|
| +
|
| + fields.push_back("Host: " + GetHostFieldValue());
|
| +
|
| + fields.push_back("Origin: " + GetOriginFieldValue());
|
| +
|
| + if (!protocol_.empty())
|
| + fields.push_back("Sec-WebSocket-Protocol: " + protocol_);
|
| +
|
| + // TODO(ukai): Add cookie if necessary.
|
| +
|
| + fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1());
|
| + fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2());
|
| +
|
| + std::random_shuffle(fields.begin(), fields.end());
|
| +
|
| + for (size_t i = 0; i < fields.size(); i++) {
|
| + msg += fields[i] + "\r\n";
|
| + }
|
| + msg += "\r\n";
|
| +
|
| + msg.append(parameter_->GetKey3());
|
| + return msg;
|
| +}
|
| +
|
| +int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
|
| + mode_ = MODE_INCOMPLETE;
|
| + int eoh = HttpUtil::LocateEndOfHeaders(data, len);
|
| + if (eoh < 0)
|
| + return -1;
|
| +
|
| + scoped_refptr<HttpResponseHeaders> headers(
|
| + new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
|
| +
|
| + if (headers->response_code() != 101) {
|
| + mode_ = MODE_FAILED;
|
| + DLOG(INFO) << "Bad response code: " << headers->response_code();
|
| + return eoh;
|
| + }
|
| + mode_ = MODE_NORMAL;
|
| + if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) {
|
| + DLOG(INFO) << "Process Headers failed: "
|
| + << std::string(data, eoh);
|
| + mode_ = MODE_FAILED;
|
| + return eoh;
|
| + }
|
| + if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) {
|
| + mode_ = MODE_INCOMPLETE;
|
| + return -1;
|
| + }
|
| + uint8 expected[Parameter::kExpectedResponseSize];
|
| + parameter_->GetExpectedResponse(expected);
|
| + if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) {
|
| + mode_ = MODE_FAILED;
|
| + return eoh + Parameter::kExpectedResponseSize;
|
| + }
|
| + mode_ = MODE_CONNECTED;
|
| + return eoh + Parameter::kExpectedResponseSize;
|
| +}
|
| +
|
| +std::string WebSocketHandshake::GetResourceName() const {
|
| + std::string resource_name = url_.path();
|
| if (url_.has_query()) {
|
| - msg += "?";
|
| - msg += url_.query();
|
| + resource_name += "?";
|
| + resource_name += url_.query();
|
| }
|
| - msg += " HTTP/1.1\r\n";
|
| - msg += kUpgradeHeader;
|
| - msg += kConnectionHeader;
|
| - msg += "Host: ";
|
| - msg += StringToLowerASCII(url_.host());
|
| + return resource_name;
|
| +}
|
| +
|
| +std::string WebSocketHandshake::GetHostFieldValue() const {
|
| + // url_.host() is expected to be encoded in punnycode here.
|
| + std::string host = StringToLowerASCII(url_.host());
|
| if (url_.has_port()) {
|
| bool secure = is_secure();
|
| int port = url_.EffectiveIntPort();
|
| @@ -66,12 +133,14 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const {
|
| port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) ||
|
| (secure &&
|
| port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) {
|
| - msg += ":";
|
| - msg += IntToString(port);
|
| + host += ":";
|
| + host += IntToString(port);
|
| }
|
| }
|
| - msg += "\r\n";
|
| - msg += "Origin: ";
|
| + return host;
|
| +}
|
| +
|
| +std::string WebSocketHandshake::GetOriginFieldValue() const {
|
| // It's OK to lowercase the origin as the Origin header does not contain
|
| // the path or query portions, as per
|
| // http://tools.ietf.org/html/draft-abarth-origin-00.
|
| @@ -79,91 +148,13 @@ std::string WebSocketHandshake::CreateClientHandshakeMessage() const {
|
| // TODO(satorux): Should we trim the port portion here if it's 80 for
|
| // http:// or 443 for https:// ? Or can we assume it's done by the
|
| // client of the library?
|
| - msg += StringToLowerASCII(origin_);
|
| - msg += "\r\n";
|
| - if (!protocol_.empty()) {
|
| - msg += "WebSocket-Protocol: ";
|
| - msg += protocol_;
|
| - msg += "\r\n";
|
| - }
|
| - // TODO(ukai): Add cookie if necessary.
|
| - msg += "\r\n";
|
| - return msg;
|
| + return StringToLowerASCII(origin_);
|
| }
|
|
|
| -int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) {
|
| - mode_ = MODE_INCOMPLETE;
|
| - if (len < kServerHandshakeHeaderLength) {
|
| - return -1;
|
| - }
|
| - if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) {
|
| - mode_ = MODE_NORMAL;
|
| - } else {
|
| - int eoh = HttpUtil::LocateEndOfHeaders(data, len);
|
| - if (eoh < 0)
|
| - return -1;
|
| - return eoh;
|
| - }
|
| - const char* p = data + kServerHandshakeHeaderLength;
|
| - const char* end = data + len + 1;
|
| -
|
| - if (mode_ == MODE_NORMAL) {
|
| - size_t header_size = end - p;
|
| - if (header_size < kUpgradeHeaderLength)
|
| - return -1;
|
| - if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) {
|
| - mode_ = MODE_FAILED;
|
| - DLOG(INFO) << "Bad Upgrade Header "
|
| - << std::string(p, kUpgradeHeaderLength);
|
| - return p - data;
|
| - }
|
| - p += kUpgradeHeaderLength;
|
| - header_size = end - p;
|
| - if (header_size < kConnectionHeaderLength)
|
| - return -1;
|
| - if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) {
|
| - mode_ = MODE_FAILED;
|
| - DLOG(INFO) << "Bad Connection Header "
|
| - << std::string(p, kConnectionHeaderLength);
|
| - return p - data;
|
| - }
|
| - p += kConnectionHeaderLength;
|
| - }
|
| -
|
| - int eoh = HttpUtil::LocateEndOfHeaders(data, len);
|
| - if (eoh == -1)
|
| - return eoh;
|
| -
|
| - scoped_refptr<HttpResponseHeaders> headers(
|
| - new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh)));
|
| - if (!ProcessHeaders(*headers)) {
|
| - DLOG(INFO) << "Process Headers failed: "
|
| - << std::string(data, eoh);
|
| - mode_ = MODE_FAILED;
|
| - }
|
| - switch (mode_) {
|
| - case MODE_NORMAL:
|
| - if (CheckResponseHeaders()) {
|
| - mode_ = MODE_CONNECTED;
|
| - } else {
|
| - mode_ = MODE_FAILED;
|
| - }
|
| - break;
|
| - default:
|
| - mode_ = MODE_FAILED;
|
| - break;
|
| - }
|
| - return eoh;
|
| -}
|
| -
|
| -// Gets the value of the specified header.
|
| -// It assures only one header of |name| in |headers|.
|
| -// Returns true iff single header of |name| is found in |headers|
|
| -// and |value| is filled with the value.
|
| -// Returns false otherwise.
|
| -static bool GetSingleHeader(const HttpResponseHeaders& headers,
|
| - const std::string& name,
|
| - std::string* value) {
|
| +/* static */
|
| +bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers,
|
| + const std::string& name,
|
| + std::string* value) {
|
| std::string first_value;
|
| void* iter = NULL;
|
| if (!headers.EnumerateHeader(&iter, name, &first_value))
|
| @@ -179,16 +170,25 @@ static bool GetSingleHeader(const HttpResponseHeaders& headers,
|
| }
|
|
|
| bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) {
|
| - if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_))
|
| + std::string value;
|
| + if (!GetSingleHeader(headers, "upgrade", &value) ||
|
| + value != "WebSocket")
|
| + return false;
|
| +
|
| + if (!GetSingleHeader(headers, "connection", &value) ||
|
| + !LowerCaseEqualsASCII(value, "upgrade"))
|
| + return false;
|
| +
|
| + if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_))
|
| return false;
|
|
|
| - if (!GetSingleHeader(headers, "websocket-location", &ws_location_))
|
| + if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_))
|
| return false;
|
|
|
| // If |protocol_| is not specified by client, we don't care if there's
|
| // protocol field or not as specified in the spec.
|
| if (!protocol_.empty()
|
| - && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_))
|
| + && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_))
|
| return false;
|
| return true;
|
| }
|
| @@ -204,6 +204,100 @@ bool WebSocketHandshake::CheckResponseHeaders() const {
|
| return true;
|
| }
|
|
|
| +namespace {
|
| +
|
| +// unsigned int version of base::RandInt().
|
| +// we can't use base::RandInt(), because max would be negative if it is
|
| +// represented as int, so DCHECK(min <= max) fails.
|
| +uint32 RandUint32(uint32 min, uint32 max) {
|
| + DCHECK(min <= max);
|
| +
|
| + uint64 range = static_cast<int64>(max) - min + 1;
|
| + uint64 number = base::RandUint64();
|
| + // TODO(ukai): fix to be uniform.
|
| + // the distribution of the result of modulo will be biased.
|
| + uint32 result = min + static_cast<uint32>(number % range);
|
| + DCHECK(result >= min && result <= max);
|
| + return result;
|
| +}
|
| +
|
| +}
|
| +
|
| +uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) =
|
| + RandUint32;
|
| +uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39];
|
|
|
| +WebSocketHandshake::Parameter::Parameter()
|
| + : number_1_(0), number_2_(0) {
|
| + if (randomCharacterInSecWebSocketKey[0] == '\0') {
|
| + int i = 0;
|
| + for (int ch = 0x21; ch <= 0x2F; ch++, i++)
|
| + randomCharacterInSecWebSocketKey[i] = ch;
|
| + for (int ch = 0x3A; ch <= 0x7E; ch++, i++)
|
| + randomCharacterInSecWebSocketKey[i] = ch;
|
| + }
|
| +}
|
| +
|
| +WebSocketHandshake::Parameter::~Parameter() {}
|
| +
|
| +void WebSocketHandshake::Parameter::GenerateKeys() {
|
| + GenerateSecWebSocketKey(&number_1_, &key_1_);
|
| + GenerateSecWebSocketKey(&number_2_, &key_2_);
|
| + GenerateKey3();
|
| +}
|
| +
|
| +static void SetChallengeNumber(uint8* buf, uint32 number) {
|
| + uint8* p = buf + 3;
|
| + for (int i = 0; i < 4; i++) {
|
| + *p = (uint8)(number & 0xFF);
|
| + --p;
|
| + number >>= 8;
|
| + }
|
| +}
|
| +
|
| +void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const {
|
| + uint8 challenge[kExpectedResponseSize];
|
| + SetChallengeNumber(&challenge[0], number_1_);
|
| + SetChallengeNumber(&challenge[4], number_2_);
|
| + memcpy(&challenge[8], key_3_.data(), kKey3Size);
|
| + MD5Digest digest;
|
| + MD5Sum(challenge, kExpectedResponseSize, &digest);
|
| + memcpy(expected, digest.a, kExpectedResponseSize);
|
| +}
|
| +
|
| +/* static */
|
| +void WebSocketHandshake::Parameter::SetRandomNumberGenerator(
|
| + uint32 (*rand)(uint32 min, uint32 max)) {
|
| + rand_ = rand;
|
| +}
|
| +
|
| +void WebSocketHandshake::Parameter::GenerateSecWebSocketKey(
|
| + uint32* number, std::string* key) {
|
| + uint32 space = rand_(1, 12);
|
| + uint32 max = 4294967295U / space;
|
| + *number = rand_(0, max);
|
| + uint32 product = *number * space;
|
| +
|
| + std::string s = StringPrintf("%010u", product);
|
| + for (uint32 i = 0; i < space; i++) {
|
| + int pos = rand_(1, s.length() - 1);
|
| + s = s.substr(0, pos) + " " + s.substr(pos);
|
| + }
|
| + int n = rand_(1, 12);
|
| + for (int i = 0; i < n; i++) {
|
| + int pos = rand_(0, s.length());
|
| + int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1);
|
| + s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) +
|
| + s.substr(pos);
|
| + }
|
| + *key = s;
|
| +}
|
| +
|
| +void WebSocketHandshake::Parameter::GenerateKey3() {
|
| + key_3_.clear();
|
| + for (int i = 0; i < 8; i++) {
|
| + key_3_.append(1, rand_(0, 255));
|
| + }
|
| +}
|
|
|
| } // namespace net
|
|
|