OLD | NEW |
1 // Copyright (c) 2010 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2010 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_handshake.h" | 5 #include "net/websockets/websocket_handshake.h" |
6 | 6 |
| 7 #include <algorithm> |
| 8 #include <vector> |
| 9 |
| 10 #include "base/md5.h" |
| 11 #include "base/rand_util.h" |
7 #include "base/ref_counted.h" | 12 #include "base/ref_counted.h" |
8 #include "base/string_util.h" | 13 #include "base/string_util.h" |
9 #include "net/http/http_response_headers.h" | 14 #include "net/http/http_response_headers.h" |
10 #include "net/http/http_util.h" | 15 #include "net/http/http_util.h" |
11 | 16 |
12 namespace net { | 17 namespace net { |
13 | 18 |
14 const int WebSocketHandshake::kWebSocketPort = 80; | 19 const int WebSocketHandshake::kWebSocketPort = 80; |
15 const int WebSocketHandshake::kSecureWebSocketPort = 443; | 20 const int WebSocketHandshake::kSecureWebSocketPort = 443; |
16 | 21 |
17 const char WebSocketHandshake::kServerHandshakeHeader[] = | |
18 "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; | |
19 const size_t WebSocketHandshake::kServerHandshakeHeaderLength = | |
20 sizeof(kServerHandshakeHeader) - 1; | |
21 | |
22 const char WebSocketHandshake::kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; | |
23 const size_t WebSocketHandshake::kUpgradeHeaderLength = | |
24 sizeof(kUpgradeHeader) - 1; | |
25 | |
26 const char WebSocketHandshake::kConnectionHeader[] = "Connection: Upgrade\r\n"; | |
27 const size_t WebSocketHandshake::kConnectionHeaderLength = | |
28 sizeof(kConnectionHeader) - 1; | |
29 | |
30 WebSocketHandshake::WebSocketHandshake( | 22 WebSocketHandshake::WebSocketHandshake( |
31 const GURL& url, | 23 const GURL& url, |
32 const std::string& origin, | 24 const std::string& origin, |
33 const std::string& location, | 25 const std::string& location, |
34 const std::string& protocol) | 26 const std::string& protocol) |
35 : url_(url), | 27 : url_(url), |
36 origin_(origin), | 28 origin_(origin), |
37 location_(location), | 29 location_(location), |
38 protocol_(protocol), | 30 protocol_(protocol), |
39 mode_(MODE_INCOMPLETE) { | 31 mode_(MODE_INCOMPLETE) { |
40 } | 32 } |
41 | 33 |
42 WebSocketHandshake::~WebSocketHandshake() { | 34 WebSocketHandshake::~WebSocketHandshake() { |
43 } | 35 } |
44 | 36 |
45 bool WebSocketHandshake::is_secure() const { | 37 bool WebSocketHandshake::is_secure() const { |
46 return url_.SchemeIs("wss"); | 38 return url_.SchemeIs("wss"); |
47 } | 39 } |
48 | 40 |
49 std::string WebSocketHandshake::CreateClientHandshakeMessage() const { | 41 std::string WebSocketHandshake::CreateClientHandshakeMessage() { |
| 42 if (!parameter_.get()) { |
| 43 parameter_.reset(new Parameter); |
| 44 parameter_->GenerateKeys(); |
| 45 } |
50 std::string msg; | 46 std::string msg; |
| 47 |
| 48 // WebSocket protocol 4.1 Opening handshake. |
| 49 |
51 msg = "GET "; | 50 msg = "GET "; |
52 msg += url_.path(); | 51 msg += GetResourceName(); |
| 52 msg += " HTTP/1.1\r\n"; |
| 53 |
| 54 std::vector<std::string> fields; |
| 55 |
| 56 fields.push_back("Upgrade: WebSocket"); |
| 57 fields.push_back("Connection: Upgrade"); |
| 58 |
| 59 fields.push_back("Host: " + GetHostFieldValue()); |
| 60 |
| 61 fields.push_back("Origin: " + GetOriginFieldValue()); |
| 62 |
| 63 if (!protocol_.empty()) |
| 64 fields.push_back("Sec-WebSocket-Protocol: " + protocol_); |
| 65 |
| 66 // TODO(ukai): Add cookie if necessary. |
| 67 |
| 68 fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); |
| 69 fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); |
| 70 |
| 71 std::random_shuffle(fields.begin(), fields.end()); |
| 72 |
| 73 for (size_t i = 0; i < fields.size(); i++) { |
| 74 msg += fields[i] + "\r\n"; |
| 75 } |
| 76 msg += "\r\n"; |
| 77 |
| 78 msg.append(parameter_->GetKey3()); |
| 79 return msg; |
| 80 } |
| 81 |
| 82 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { |
| 83 mode_ = MODE_INCOMPLETE; |
| 84 int eoh = HttpUtil::LocateEndOfHeaders(data, len); |
| 85 if (eoh < 0) |
| 86 return -1; |
| 87 |
| 88 scoped_refptr<HttpResponseHeaders> headers( |
| 89 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); |
| 90 |
| 91 if (headers->response_code() != 101) { |
| 92 mode_ = MODE_FAILED; |
| 93 DLOG(INFO) << "Bad response code: " << headers->response_code(); |
| 94 return eoh; |
| 95 } |
| 96 mode_ = MODE_NORMAL; |
| 97 if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { |
| 98 DLOG(INFO) << "Process Headers failed: " |
| 99 << std::string(data, eoh); |
| 100 mode_ = MODE_FAILED; |
| 101 return eoh; |
| 102 } |
| 103 if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) { |
| 104 mode_ = MODE_INCOMPLETE; |
| 105 return -1; |
| 106 } |
| 107 uint8 expected[Parameter::kExpectedResponseSize]; |
| 108 parameter_->GetExpectedResponse(expected); |
| 109 if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { |
| 110 mode_ = MODE_FAILED; |
| 111 return eoh + Parameter::kExpectedResponseSize; |
| 112 } |
| 113 mode_ = MODE_CONNECTED; |
| 114 return eoh + Parameter::kExpectedResponseSize; |
| 115 } |
| 116 |
| 117 std::string WebSocketHandshake::GetResourceName() const { |
| 118 std::string resource_name = url_.path(); |
53 if (url_.has_query()) { | 119 if (url_.has_query()) { |
54 msg += "?"; | 120 resource_name += "?"; |
55 msg += url_.query(); | 121 resource_name += url_.query(); |
56 } | 122 } |
57 msg += " HTTP/1.1\r\n"; | 123 return resource_name; |
58 msg += kUpgradeHeader; | 124 } |
59 msg += kConnectionHeader; | 125 |
60 msg += "Host: "; | 126 std::string WebSocketHandshake::GetHostFieldValue() const { |
61 msg += StringToLowerASCII(url_.host()); | 127 // url_.host() is expected to be encoded in punnycode here. |
| 128 std::string host = StringToLowerASCII(url_.host()); |
62 if (url_.has_port()) { | 129 if (url_.has_port()) { |
63 bool secure = is_secure(); | 130 bool secure = is_secure(); |
64 int port = url_.EffectiveIntPort(); | 131 int port = url_.EffectiveIntPort(); |
65 if ((!secure && | 132 if ((!secure && |
66 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || | 133 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || |
67 (secure && | 134 (secure && |
68 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { | 135 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { |
69 msg += ":"; | 136 host += ":"; |
70 msg += IntToString(port); | 137 host += IntToString(port); |
71 } | 138 } |
72 } | 139 } |
73 msg += "\r\n"; | 140 return host; |
74 msg += "Origin: "; | 141 } |
| 142 |
| 143 std::string WebSocketHandshake::GetOriginFieldValue() const { |
75 // It's OK to lowercase the origin as the Origin header does not contain | 144 // It's OK to lowercase the origin as the Origin header does not contain |
76 // the path or query portions, as per | 145 // the path or query portions, as per |
77 // http://tools.ietf.org/html/draft-abarth-origin-00. | 146 // http://tools.ietf.org/html/draft-abarth-origin-00. |
78 // | 147 // |
79 // TODO(satorux): Should we trim the port portion here if it's 80 for | 148 // TODO(satorux): Should we trim the port portion here if it's 80 for |
80 // http:// or 443 for https:// ? Or can we assume it's done by the | 149 // http:// or 443 for https:// ? Or can we assume it's done by the |
81 // client of the library? | 150 // client of the library? |
82 msg += StringToLowerASCII(origin_); | 151 return StringToLowerASCII(origin_); |
83 msg += "\r\n"; | |
84 if (!protocol_.empty()) { | |
85 msg += "WebSocket-Protocol: "; | |
86 msg += protocol_; | |
87 msg += "\r\n"; | |
88 } | |
89 // TODO(ukai): Add cookie if necessary. | |
90 msg += "\r\n"; | |
91 return msg; | |
92 } | 152 } |
93 | 153 |
94 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { | 154 /* static */ |
95 mode_ = MODE_INCOMPLETE; | 155 bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, |
96 if (len < kServerHandshakeHeaderLength) { | 156 const std::string& name, |
97 return -1; | 157 std::string* value) { |
98 } | |
99 if (!memcmp(data, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { | |
100 mode_ = MODE_NORMAL; | |
101 } else { | |
102 int eoh = HttpUtil::LocateEndOfHeaders(data, len); | |
103 if (eoh < 0) | |
104 return -1; | |
105 return eoh; | |
106 } | |
107 const char* p = data + kServerHandshakeHeaderLength; | |
108 const char* end = data + len + 1; | |
109 | |
110 if (mode_ == MODE_NORMAL) { | |
111 size_t header_size = end - p; | |
112 if (header_size < kUpgradeHeaderLength) | |
113 return -1; | |
114 if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { | |
115 mode_ = MODE_FAILED; | |
116 DLOG(INFO) << "Bad Upgrade Header " | |
117 << std::string(p, kUpgradeHeaderLength); | |
118 return p - data; | |
119 } | |
120 p += kUpgradeHeaderLength; | |
121 header_size = end - p; | |
122 if (header_size < kConnectionHeaderLength) | |
123 return -1; | |
124 if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { | |
125 mode_ = MODE_FAILED; | |
126 DLOG(INFO) << "Bad Connection Header " | |
127 << std::string(p, kConnectionHeaderLength); | |
128 return p - data; | |
129 } | |
130 p += kConnectionHeaderLength; | |
131 } | |
132 | |
133 int eoh = HttpUtil::LocateEndOfHeaders(data, len); | |
134 if (eoh == -1) | |
135 return eoh; | |
136 | |
137 scoped_refptr<HttpResponseHeaders> headers( | |
138 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); | |
139 if (!ProcessHeaders(*headers)) { | |
140 DLOG(INFO) << "Process Headers failed: " | |
141 << std::string(data, eoh); | |
142 mode_ = MODE_FAILED; | |
143 } | |
144 switch (mode_) { | |
145 case MODE_NORMAL: | |
146 if (CheckResponseHeaders()) { | |
147 mode_ = MODE_CONNECTED; | |
148 } else { | |
149 mode_ = MODE_FAILED; | |
150 } | |
151 break; | |
152 default: | |
153 mode_ = MODE_FAILED; | |
154 break; | |
155 } | |
156 return eoh; | |
157 } | |
158 | |
159 // Gets the value of the specified header. | |
160 // It assures only one header of |name| in |headers|. | |
161 // Returns true iff single header of |name| is found in |headers| | |
162 // and |value| is filled with the value. | |
163 // Returns false otherwise. | |
164 static bool GetSingleHeader(const HttpResponseHeaders& headers, | |
165 const std::string& name, | |
166 std::string* value) { | |
167 std::string first_value; | 158 std::string first_value; |
168 void* iter = NULL; | 159 void* iter = NULL; |
169 if (!headers.EnumerateHeader(&iter, name, &first_value)) | 160 if (!headers.EnumerateHeader(&iter, name, &first_value)) |
170 return false; | 161 return false; |
171 | 162 |
172 // Checks no more |name| found in |headers|. | 163 // Checks no more |name| found in |headers|. |
173 // Second call of EnumerateHeader() must return false. | 164 // Second call of EnumerateHeader() must return false. |
174 std::string second_value; | 165 std::string second_value; |
175 if (headers.EnumerateHeader(&iter, name, &second_value)) | 166 if (headers.EnumerateHeader(&iter, name, &second_value)) |
176 return false; | 167 return false; |
177 *value = first_value; | 168 *value = first_value; |
178 return true; | 169 return true; |
179 } | 170 } |
180 | 171 |
181 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { | 172 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { |
182 if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) | 173 std::string value; |
| 174 if (!GetSingleHeader(headers, "upgrade", &value) || |
| 175 value != "WebSocket") |
183 return false; | 176 return false; |
184 | 177 |
185 if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) | 178 if (!GetSingleHeader(headers, "connection", &value) || |
| 179 !LowerCaseEqualsASCII(value, "upgrade")) |
| 180 return false; |
| 181 |
| 182 if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) |
| 183 return false; |
| 184 |
| 185 if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) |
186 return false; | 186 return false; |
187 | 187 |
188 // If |protocol_| is not specified by client, we don't care if there's | 188 // If |protocol_| is not specified by client, we don't care if there's |
189 // protocol field or not as specified in the spec. | 189 // protocol field or not as specified in the spec. |
190 if (!protocol_.empty() | 190 if (!protocol_.empty() |
191 && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) | 191 && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) |
192 return false; | 192 return false; |
193 return true; | 193 return true; |
194 } | 194 } |
195 | 195 |
196 bool WebSocketHandshake::CheckResponseHeaders() const { | 196 bool WebSocketHandshake::CheckResponseHeaders() const { |
197 DCHECK(mode_ == MODE_NORMAL); | 197 DCHECK(mode_ == MODE_NORMAL); |
198 if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) | 198 if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) |
199 return false; | 199 return false; |
200 if (location_ != ws_location_) | 200 if (location_ != ws_location_) |
201 return false; | 201 return false; |
202 if (!protocol_.empty() && protocol_ != ws_protocol_) | 202 if (!protocol_.empty() && protocol_ != ws_protocol_) |
203 return false; | 203 return false; |
204 return true; | 204 return true; |
205 } | 205 } |
206 | 206 |
| 207 namespace { |
207 | 208 |
| 209 // unsigned int version of base::RandInt(). |
| 210 // we can't use base::RandInt(), because max would be negative if it is |
| 211 // represented as int, so DCHECK(min <= max) fails. |
| 212 uint32 RandUint32(uint32 min, uint32 max) { |
| 213 DCHECK(min <= max); |
| 214 |
| 215 uint64 range = static_cast<int64>(max) - min + 1; |
| 216 uint64 number = base::RandUint64(); |
| 217 // TODO(ukai): fix to be uniform. |
| 218 // the distribution of the result of modulo will be biased. |
| 219 uint32 result = min + static_cast<uint32>(number % range); |
| 220 DCHECK(result >= min && result <= max); |
| 221 return result; |
| 222 } |
| 223 |
| 224 } |
| 225 |
| 226 uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = |
| 227 RandUint32; |
| 228 uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; |
| 229 |
| 230 WebSocketHandshake::Parameter::Parameter() |
| 231 : number_1_(0), number_2_(0) { |
| 232 if (randomCharacterInSecWebSocketKey[0] == '\0') { |
| 233 int i = 0; |
| 234 for (int ch = 0x21; ch <= 0x2F; ch++, i++) |
| 235 randomCharacterInSecWebSocketKey[i] = ch; |
| 236 for (int ch = 0x3A; ch <= 0x7E; ch++, i++) |
| 237 randomCharacterInSecWebSocketKey[i] = ch; |
| 238 } |
| 239 } |
| 240 |
| 241 WebSocketHandshake::Parameter::~Parameter() {} |
| 242 |
| 243 void WebSocketHandshake::Parameter::GenerateKeys() { |
| 244 GenerateSecWebSocketKey(&number_1_, &key_1_); |
| 245 GenerateSecWebSocketKey(&number_2_, &key_2_); |
| 246 GenerateKey3(); |
| 247 } |
| 248 |
| 249 static void SetChallengeNumber(uint8* buf, uint32 number) { |
| 250 uint8* p = buf + 3; |
| 251 for (int i = 0; i < 4; i++) { |
| 252 *p = (uint8)(number & 0xFF); |
| 253 --p; |
| 254 number >>= 8; |
| 255 } |
| 256 } |
| 257 |
| 258 void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { |
| 259 uint8 challenge[kExpectedResponseSize]; |
| 260 SetChallengeNumber(&challenge[0], number_1_); |
| 261 SetChallengeNumber(&challenge[4], number_2_); |
| 262 memcpy(&challenge[8], key_3_.data(), kKey3Size); |
| 263 MD5Digest digest; |
| 264 MD5Sum(challenge, kExpectedResponseSize, &digest); |
| 265 memcpy(expected, digest.a, kExpectedResponseSize); |
| 266 } |
| 267 |
| 268 /* static */ |
| 269 void WebSocketHandshake::Parameter::SetRandomNumberGenerator( |
| 270 uint32 (*rand)(uint32 min, uint32 max)) { |
| 271 rand_ = rand; |
| 272 } |
| 273 |
| 274 void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( |
| 275 uint32* number, std::string* key) { |
| 276 uint32 space = rand_(1, 12); |
| 277 uint32 max = 4294967295U / space; |
| 278 *number = rand_(0, max); |
| 279 uint32 product = *number * space; |
| 280 |
| 281 std::string s = StringPrintf("%010u", product); |
| 282 for (uint32 i = 0; i < space; i++) { |
| 283 int pos = rand_(1, s.length() - 1); |
| 284 s = s.substr(0, pos) + " " + s.substr(pos); |
| 285 } |
| 286 int n = rand_(1, 12); |
| 287 for (int i = 0; i < n; i++) { |
| 288 int pos = rand_(0, s.length()); |
| 289 int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); |
| 290 s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + |
| 291 s.substr(pos); |
| 292 } |
| 293 *key = s; |
| 294 } |
| 295 |
| 296 void WebSocketHandshake::Parameter::GenerateKey3() { |
| 297 key_3_.clear(); |
| 298 for (int i = 0; i < 8; i++) { |
| 299 key_3_.append(1, rand_(0, 255)); |
| 300 } |
| 301 } |
208 | 302 |
209 } // namespace net | 303 } // namespace net |
OLD | NEW |