| OLD | NEW |
| (Empty) |
| 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "net/websockets/websocket_handshake.h" | |
| 6 | |
| 7 #include <algorithm> | |
| 8 #include <vector> | |
| 9 | |
| 10 #include "base/logging.h" | |
| 11 #include "base/md5.h" | |
| 12 #include "base/memory/ref_counted.h" | |
| 13 #include "base/rand_util.h" | |
| 14 #include "base/string_number_conversions.h" | |
| 15 #include "base/string_util.h" | |
| 16 #include "base/stringprintf.h" | |
| 17 #include "net/http/http_response_headers.h" | |
| 18 #include "net/http/http_util.h" | |
| 19 | |
| 20 namespace net { | |
| 21 | |
| 22 const int WebSocketHandshake::kWebSocketPort = 80; | |
| 23 const int WebSocketHandshake::kSecureWebSocketPort = 443; | |
| 24 | |
| 25 WebSocketHandshake::WebSocketHandshake( | |
| 26 const GURL& url, | |
| 27 const std::string& origin, | |
| 28 const std::string& location, | |
| 29 const std::string& protocol) | |
| 30 : url_(url), | |
| 31 origin_(origin), | |
| 32 location_(location), | |
| 33 protocol_(protocol), | |
| 34 mode_(MODE_INCOMPLETE) { | |
| 35 } | |
| 36 | |
| 37 WebSocketHandshake::~WebSocketHandshake() { | |
| 38 } | |
| 39 | |
| 40 bool WebSocketHandshake::is_secure() const { | |
| 41 return url_.SchemeIs("wss"); | |
| 42 } | |
| 43 | |
| 44 std::string WebSocketHandshake::CreateClientHandshakeMessage() { | |
| 45 if (!parameter_.get()) { | |
| 46 parameter_.reset(new Parameter); | |
| 47 parameter_->GenerateKeys(); | |
| 48 } | |
| 49 std::string msg; | |
| 50 | |
| 51 // WebSocket protocol 4.1 Opening handshake. | |
| 52 | |
| 53 msg = "GET "; | |
| 54 msg += GetResourceName(); | |
| 55 msg += " HTTP/1.1\r\n"; | |
| 56 | |
| 57 std::vector<std::string> fields; | |
| 58 | |
| 59 fields.push_back("Upgrade: WebSocket"); | |
| 60 fields.push_back("Connection: Upgrade"); | |
| 61 | |
| 62 fields.push_back("Host: " + GetHostFieldValue()); | |
| 63 | |
| 64 fields.push_back("Origin: " + GetOriginFieldValue()); | |
| 65 | |
| 66 if (!protocol_.empty()) | |
| 67 fields.push_back("Sec-WebSocket-Protocol: " + protocol_); | |
| 68 | |
| 69 // TODO(ukai): Add cookie if necessary. | |
| 70 | |
| 71 fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); | |
| 72 fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); | |
| 73 | |
| 74 std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator); | |
| 75 | |
| 76 for (size_t i = 0; i < fields.size(); i++) { | |
| 77 msg += fields[i] + "\r\n"; | |
| 78 } | |
| 79 msg += "\r\n"; | |
| 80 | |
| 81 msg.append(parameter_->GetKey3()); | |
| 82 return msg; | |
| 83 } | |
| 84 | |
| 85 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { | |
| 86 mode_ = MODE_INCOMPLETE; | |
| 87 int eoh = HttpUtil::LocateEndOfHeaders(data, len); | |
| 88 if (eoh < 0) | |
| 89 return -1; | |
| 90 | |
| 91 scoped_refptr<HttpResponseHeaders> headers( | |
| 92 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); | |
| 93 | |
| 94 if (headers->response_code() != 101) { | |
| 95 mode_ = MODE_FAILED; | |
| 96 DVLOG(1) << "Bad response code: " << headers->response_code(); | |
| 97 return eoh; | |
| 98 } | |
| 99 mode_ = MODE_NORMAL; | |
| 100 if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { | |
| 101 DVLOG(1) << "Process Headers failed: " << std::string(data, eoh); | |
| 102 mode_ = MODE_FAILED; | |
| 103 return eoh; | |
| 104 } | |
| 105 if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) { | |
| 106 mode_ = MODE_INCOMPLETE; | |
| 107 return -1; | |
| 108 } | |
| 109 uint8 expected[Parameter::kExpectedResponseSize]; | |
| 110 parameter_->GetExpectedResponse(expected); | |
| 111 if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { | |
| 112 mode_ = MODE_FAILED; | |
| 113 return eoh + Parameter::kExpectedResponseSize; | |
| 114 } | |
| 115 mode_ = MODE_CONNECTED; | |
| 116 return eoh + Parameter::kExpectedResponseSize; | |
| 117 } | |
| 118 | |
| 119 std::string WebSocketHandshake::GetResourceName() const { | |
| 120 std::string resource_name = url_.path(); | |
| 121 if (url_.has_query()) { | |
| 122 resource_name += "?"; | |
| 123 resource_name += url_.query(); | |
| 124 } | |
| 125 return resource_name; | |
| 126 } | |
| 127 | |
| 128 std::string WebSocketHandshake::GetHostFieldValue() const { | |
| 129 // url_.host() is expected to be encoded in punnycode here. | |
| 130 std::string host = StringToLowerASCII(url_.host()); | |
| 131 if (url_.has_port()) { | |
| 132 bool secure = is_secure(); | |
| 133 int port = url_.EffectiveIntPort(); | |
| 134 if ((!secure && | |
| 135 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || | |
| 136 (secure && | |
| 137 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { | |
| 138 host += ":"; | |
| 139 host += base::IntToString(port); | |
| 140 } | |
| 141 } | |
| 142 return host; | |
| 143 } | |
| 144 | |
| 145 std::string WebSocketHandshake::GetOriginFieldValue() const { | |
| 146 // It's OK to lowercase the origin as the Origin header does not contain | |
| 147 // the path or query portions, as per | |
| 148 // http://tools.ietf.org/html/draft-abarth-origin-00. | |
| 149 // | |
| 150 // TODO(satorux): Should we trim the port portion here if it's 80 for | |
| 151 // http:// or 443 for https:// ? Or can we assume it's done by the | |
| 152 // client of the library? | |
| 153 return StringToLowerASCII(origin_); | |
| 154 } | |
| 155 | |
| 156 /* static */ | |
| 157 bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, | |
| 158 const std::string& name, | |
| 159 std::string* value) { | |
| 160 std::string first_value; | |
| 161 void* iter = NULL; | |
| 162 if (!headers.EnumerateHeader(&iter, name, &first_value)) | |
| 163 return false; | |
| 164 | |
| 165 // Checks no more |name| found in |headers|. | |
| 166 // Second call of EnumerateHeader() must return false. | |
| 167 std::string second_value; | |
| 168 if (headers.EnumerateHeader(&iter, name, &second_value)) | |
| 169 return false; | |
| 170 *value = first_value; | |
| 171 return true; | |
| 172 } | |
| 173 | |
| 174 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { | |
| 175 std::string value; | |
| 176 if (!GetSingleHeader(headers, "upgrade", &value) || | |
| 177 value != "WebSocket") | |
| 178 return false; | |
| 179 | |
| 180 if (!GetSingleHeader(headers, "connection", &value) || | |
| 181 !LowerCaseEqualsASCII(value, "upgrade")) | |
| 182 return false; | |
| 183 | |
| 184 if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) | |
| 185 return false; | |
| 186 | |
| 187 if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) | |
| 188 return false; | |
| 189 | |
| 190 // If |protocol_| is not specified by client, we don't care if there's | |
| 191 // protocol field or not as specified in the spec. | |
| 192 if (!protocol_.empty() | |
| 193 && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) | |
| 194 return false; | |
| 195 return true; | |
| 196 } | |
| 197 | |
| 198 bool WebSocketHandshake::CheckResponseHeaders() const { | |
| 199 DCHECK(mode_ == MODE_NORMAL); | |
| 200 if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) | |
| 201 return false; | |
| 202 if (location_ != ws_location_) | |
| 203 return false; | |
| 204 if (!protocol_.empty() && protocol_ != ws_protocol_) | |
| 205 return false; | |
| 206 return true; | |
| 207 } | |
| 208 | |
| 209 namespace { | |
| 210 | |
| 211 // unsigned int version of base::RandInt(). | |
| 212 // we can't use base::RandInt(), because max would be negative if it is | |
| 213 // represented as int, so DCHECK(min <= max) fails. | |
| 214 uint32 RandUint32(uint32 min, uint32 max) { | |
| 215 DCHECK(min <= max); | |
| 216 | |
| 217 uint64 range = static_cast<int64>(max) - min + 1; | |
| 218 uint64 number = base::RandGenerator(range); | |
| 219 uint32 result = min + static_cast<uint32>(number); | |
| 220 DCHECK(result >= min && result <= max); | |
| 221 return result; | |
| 222 } | |
| 223 | |
| 224 } | |
| 225 | |
| 226 uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = | |
| 227 RandUint32; | |
| 228 uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; | |
| 229 | |
| 230 WebSocketHandshake::Parameter::Parameter() | |
| 231 : number_1_(0), number_2_(0) { | |
| 232 if (randomCharacterInSecWebSocketKey[0] == '\0') { | |
| 233 int i = 0; | |
| 234 for (int ch = 0x21; ch <= 0x2F; ch++, i++) | |
| 235 randomCharacterInSecWebSocketKey[i] = ch; | |
| 236 for (int ch = 0x3A; ch <= 0x7E; ch++, i++) | |
| 237 randomCharacterInSecWebSocketKey[i] = ch; | |
| 238 } | |
| 239 } | |
| 240 | |
| 241 WebSocketHandshake::Parameter::~Parameter() {} | |
| 242 | |
| 243 void WebSocketHandshake::Parameter::GenerateKeys() { | |
| 244 GenerateSecWebSocketKey(&number_1_, &key_1_); | |
| 245 GenerateSecWebSocketKey(&number_2_, &key_2_); | |
| 246 GenerateKey3(); | |
| 247 } | |
| 248 | |
| 249 static void SetChallengeNumber(uint8* buf, uint32 number) { | |
| 250 uint8* p = buf + 3; | |
| 251 for (int i = 0; i < 4; i++) { | |
| 252 *p = (uint8)(number & 0xFF); | |
| 253 --p; | |
| 254 number >>= 8; | |
| 255 } | |
| 256 } | |
| 257 | |
| 258 void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { | |
| 259 uint8 challenge[kExpectedResponseSize]; | |
| 260 SetChallengeNumber(&challenge[0], number_1_); | |
| 261 SetChallengeNumber(&challenge[4], number_2_); | |
| 262 memcpy(&challenge[8], key_3_.data(), kKey3Size); | |
| 263 base::MD5Digest digest; | |
| 264 base::MD5Sum(challenge, kExpectedResponseSize, &digest); | |
| 265 memcpy(expected, digest.a, kExpectedResponseSize); | |
| 266 } | |
| 267 | |
| 268 /* static */ | |
| 269 void WebSocketHandshake::Parameter::SetRandomNumberGenerator( | |
| 270 uint32 (*rand)(uint32 min, uint32 max)) { | |
| 271 rand_ = rand; | |
| 272 } | |
| 273 | |
| 274 void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( | |
| 275 uint32* number, std::string* key) { | |
| 276 uint32 space = rand_(1, 12); | |
| 277 uint32 max = 4294967295U / space; | |
| 278 *number = rand_(0, max); | |
| 279 uint32 product = *number * space; | |
| 280 | |
| 281 std::string s = base::StringPrintf("%u", product); | |
| 282 int n = rand_(1, 12); | |
| 283 for (int i = 0; i < n; i++) { | |
| 284 int pos = rand_(0, s.length()); | |
| 285 int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); | |
| 286 s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + | |
| 287 s.substr(pos); | |
| 288 } | |
| 289 for (uint32 i = 0; i < space; i++) { | |
| 290 int pos = rand_(1, s.length() - 1); | |
| 291 s = s.substr(0, pos) + " " + s.substr(pos); | |
| 292 } | |
| 293 *key = s; | |
| 294 } | |
| 295 | |
| 296 void WebSocketHandshake::Parameter::GenerateKey3() { | |
| 297 key_3_.clear(); | |
| 298 for (int i = 0; i < 8; i++) { | |
| 299 key_3_.append(1, rand_(0, 255)); | |
| 300 } | |
| 301 } | |
| 302 | |
| 303 } // namespace net | |
| OLD | NEW |