OLD | NEW |
| (Empty) |
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "net/websockets/websocket_handshake.h" | |
6 | |
7 #include <algorithm> | |
8 #include <vector> | |
9 | |
10 #include "base/logging.h" | |
11 #include "base/md5.h" | |
12 #include "base/memory/ref_counted.h" | |
13 #include "base/rand_util.h" | |
14 #include "base/string_number_conversions.h" | |
15 #include "base/string_util.h" | |
16 #include "base/stringprintf.h" | |
17 #include "net/http/http_response_headers.h" | |
18 #include "net/http/http_util.h" | |
19 | |
20 namespace net { | |
21 | |
22 const int WebSocketHandshake::kWebSocketPort = 80; | |
23 const int WebSocketHandshake::kSecureWebSocketPort = 443; | |
24 | |
25 WebSocketHandshake::WebSocketHandshake( | |
26 const GURL& url, | |
27 const std::string& origin, | |
28 const std::string& location, | |
29 const std::string& protocol) | |
30 : url_(url), | |
31 origin_(origin), | |
32 location_(location), | |
33 protocol_(protocol), | |
34 mode_(MODE_INCOMPLETE) { | |
35 } | |
36 | |
37 WebSocketHandshake::~WebSocketHandshake() { | |
38 } | |
39 | |
40 bool WebSocketHandshake::is_secure() const { | |
41 return url_.SchemeIs("wss"); | |
42 } | |
43 | |
44 std::string WebSocketHandshake::CreateClientHandshakeMessage() { | |
45 if (!parameter_.get()) { | |
46 parameter_.reset(new Parameter); | |
47 parameter_->GenerateKeys(); | |
48 } | |
49 std::string msg; | |
50 | |
51 // WebSocket protocol 4.1 Opening handshake. | |
52 | |
53 msg = "GET "; | |
54 msg += GetResourceName(); | |
55 msg += " HTTP/1.1\r\n"; | |
56 | |
57 std::vector<std::string> fields; | |
58 | |
59 fields.push_back("Upgrade: WebSocket"); | |
60 fields.push_back("Connection: Upgrade"); | |
61 | |
62 fields.push_back("Host: " + GetHostFieldValue()); | |
63 | |
64 fields.push_back("Origin: " + GetOriginFieldValue()); | |
65 | |
66 if (!protocol_.empty()) | |
67 fields.push_back("Sec-WebSocket-Protocol: " + protocol_); | |
68 | |
69 // TODO(ukai): Add cookie if necessary. | |
70 | |
71 fields.push_back("Sec-WebSocket-Key1: " + parameter_->GetSecWebSocketKey1()); | |
72 fields.push_back("Sec-WebSocket-Key2: " + parameter_->GetSecWebSocketKey2()); | |
73 | |
74 std::random_shuffle(fields.begin(), fields.end(), base::RandGenerator); | |
75 | |
76 for (size_t i = 0; i < fields.size(); i++) { | |
77 msg += fields[i] + "\r\n"; | |
78 } | |
79 msg += "\r\n"; | |
80 | |
81 msg.append(parameter_->GetKey3()); | |
82 return msg; | |
83 } | |
84 | |
85 int WebSocketHandshake::ReadServerHandshake(const char* data, size_t len) { | |
86 mode_ = MODE_INCOMPLETE; | |
87 int eoh = HttpUtil::LocateEndOfHeaders(data, len); | |
88 if (eoh < 0) | |
89 return -1; | |
90 | |
91 scoped_refptr<HttpResponseHeaders> headers( | |
92 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(data, eoh))); | |
93 | |
94 if (headers->response_code() != 101) { | |
95 mode_ = MODE_FAILED; | |
96 DVLOG(1) << "Bad response code: " << headers->response_code(); | |
97 return eoh; | |
98 } | |
99 mode_ = MODE_NORMAL; | |
100 if (!ProcessHeaders(*headers) || !CheckResponseHeaders()) { | |
101 DVLOG(1) << "Process Headers failed: " << std::string(data, eoh); | |
102 mode_ = MODE_FAILED; | |
103 return eoh; | |
104 } | |
105 if (len < static_cast<size_t>(eoh + Parameter::kExpectedResponseSize)) { | |
106 mode_ = MODE_INCOMPLETE; | |
107 return -1; | |
108 } | |
109 uint8 expected[Parameter::kExpectedResponseSize]; | |
110 parameter_->GetExpectedResponse(expected); | |
111 if (memcmp(&data[eoh], expected, Parameter::kExpectedResponseSize)) { | |
112 mode_ = MODE_FAILED; | |
113 return eoh + Parameter::kExpectedResponseSize; | |
114 } | |
115 mode_ = MODE_CONNECTED; | |
116 return eoh + Parameter::kExpectedResponseSize; | |
117 } | |
118 | |
119 std::string WebSocketHandshake::GetResourceName() const { | |
120 std::string resource_name = url_.path(); | |
121 if (url_.has_query()) { | |
122 resource_name += "?"; | |
123 resource_name += url_.query(); | |
124 } | |
125 return resource_name; | |
126 } | |
127 | |
128 std::string WebSocketHandshake::GetHostFieldValue() const { | |
129 // url_.host() is expected to be encoded in punnycode here. | |
130 std::string host = StringToLowerASCII(url_.host()); | |
131 if (url_.has_port()) { | |
132 bool secure = is_secure(); | |
133 int port = url_.EffectiveIntPort(); | |
134 if ((!secure && | |
135 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || | |
136 (secure && | |
137 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { | |
138 host += ":"; | |
139 host += base::IntToString(port); | |
140 } | |
141 } | |
142 return host; | |
143 } | |
144 | |
145 std::string WebSocketHandshake::GetOriginFieldValue() const { | |
146 // It's OK to lowercase the origin as the Origin header does not contain | |
147 // the path or query portions, as per | |
148 // http://tools.ietf.org/html/draft-abarth-origin-00. | |
149 // | |
150 // TODO(satorux): Should we trim the port portion here if it's 80 for | |
151 // http:// or 443 for https:// ? Or can we assume it's done by the | |
152 // client of the library? | |
153 return StringToLowerASCII(origin_); | |
154 } | |
155 | |
156 /* static */ | |
157 bool WebSocketHandshake::GetSingleHeader(const HttpResponseHeaders& headers, | |
158 const std::string& name, | |
159 std::string* value) { | |
160 std::string first_value; | |
161 void* iter = NULL; | |
162 if (!headers.EnumerateHeader(&iter, name, &first_value)) | |
163 return false; | |
164 | |
165 // Checks no more |name| found in |headers|. | |
166 // Second call of EnumerateHeader() must return false. | |
167 std::string second_value; | |
168 if (headers.EnumerateHeader(&iter, name, &second_value)) | |
169 return false; | |
170 *value = first_value; | |
171 return true; | |
172 } | |
173 | |
174 bool WebSocketHandshake::ProcessHeaders(const HttpResponseHeaders& headers) { | |
175 std::string value; | |
176 if (!GetSingleHeader(headers, "upgrade", &value) || | |
177 value != "WebSocket") | |
178 return false; | |
179 | |
180 if (!GetSingleHeader(headers, "connection", &value) || | |
181 !LowerCaseEqualsASCII(value, "upgrade")) | |
182 return false; | |
183 | |
184 if (!GetSingleHeader(headers, "sec-websocket-origin", &ws_origin_)) | |
185 return false; | |
186 | |
187 if (!GetSingleHeader(headers, "sec-websocket-location", &ws_location_)) | |
188 return false; | |
189 | |
190 // If |protocol_| is not specified by client, we don't care if there's | |
191 // protocol field or not as specified in the spec. | |
192 if (!protocol_.empty() | |
193 && !GetSingleHeader(headers, "sec-websocket-protocol", &ws_protocol_)) | |
194 return false; | |
195 return true; | |
196 } | |
197 | |
198 bool WebSocketHandshake::CheckResponseHeaders() const { | |
199 DCHECK(mode_ == MODE_NORMAL); | |
200 if (!LowerCaseEqualsASCII(origin_, ws_origin_.c_str())) | |
201 return false; | |
202 if (location_ != ws_location_) | |
203 return false; | |
204 if (!protocol_.empty() && protocol_ != ws_protocol_) | |
205 return false; | |
206 return true; | |
207 } | |
208 | |
209 namespace { | |
210 | |
211 // unsigned int version of base::RandInt(). | |
212 // we can't use base::RandInt(), because max would be negative if it is | |
213 // represented as int, so DCHECK(min <= max) fails. | |
214 uint32 RandUint32(uint32 min, uint32 max) { | |
215 DCHECK(min <= max); | |
216 | |
217 uint64 range = static_cast<int64>(max) - min + 1; | |
218 uint64 number = base::RandGenerator(range); | |
219 uint32 result = min + static_cast<uint32>(number); | |
220 DCHECK(result >= min && result <= max); | |
221 return result; | |
222 } | |
223 | |
224 } | |
225 | |
226 uint32 (*WebSocketHandshake::Parameter::rand_)(uint32 min, uint32 max) = | |
227 RandUint32; | |
228 uint8 randomCharacterInSecWebSocketKey[0x2F - 0x20 + 0x7E - 0x39]; | |
229 | |
230 WebSocketHandshake::Parameter::Parameter() | |
231 : number_1_(0), number_2_(0) { | |
232 if (randomCharacterInSecWebSocketKey[0] == '\0') { | |
233 int i = 0; | |
234 for (int ch = 0x21; ch <= 0x2F; ch++, i++) | |
235 randomCharacterInSecWebSocketKey[i] = ch; | |
236 for (int ch = 0x3A; ch <= 0x7E; ch++, i++) | |
237 randomCharacterInSecWebSocketKey[i] = ch; | |
238 } | |
239 } | |
240 | |
241 WebSocketHandshake::Parameter::~Parameter() {} | |
242 | |
243 void WebSocketHandshake::Parameter::GenerateKeys() { | |
244 GenerateSecWebSocketKey(&number_1_, &key_1_); | |
245 GenerateSecWebSocketKey(&number_2_, &key_2_); | |
246 GenerateKey3(); | |
247 } | |
248 | |
249 static void SetChallengeNumber(uint8* buf, uint32 number) { | |
250 uint8* p = buf + 3; | |
251 for (int i = 0; i < 4; i++) { | |
252 *p = (uint8)(number & 0xFF); | |
253 --p; | |
254 number >>= 8; | |
255 } | |
256 } | |
257 | |
258 void WebSocketHandshake::Parameter::GetExpectedResponse(uint8 *expected) const { | |
259 uint8 challenge[kExpectedResponseSize]; | |
260 SetChallengeNumber(&challenge[0], number_1_); | |
261 SetChallengeNumber(&challenge[4], number_2_); | |
262 memcpy(&challenge[8], key_3_.data(), kKey3Size); | |
263 base::MD5Digest digest; | |
264 base::MD5Sum(challenge, kExpectedResponseSize, &digest); | |
265 memcpy(expected, digest.a, kExpectedResponseSize); | |
266 } | |
267 | |
268 /* static */ | |
269 void WebSocketHandshake::Parameter::SetRandomNumberGenerator( | |
270 uint32 (*rand)(uint32 min, uint32 max)) { | |
271 rand_ = rand; | |
272 } | |
273 | |
274 void WebSocketHandshake::Parameter::GenerateSecWebSocketKey( | |
275 uint32* number, std::string* key) { | |
276 uint32 space = rand_(1, 12); | |
277 uint32 max = 4294967295U / space; | |
278 *number = rand_(0, max); | |
279 uint32 product = *number * space; | |
280 | |
281 std::string s = base::StringPrintf("%u", product); | |
282 int n = rand_(1, 12); | |
283 for (int i = 0; i < n; i++) { | |
284 int pos = rand_(0, s.length()); | |
285 int chpos = rand_(0, sizeof(randomCharacterInSecWebSocketKey) - 1); | |
286 s = s.substr(0, pos).append(1, randomCharacterInSecWebSocketKey[chpos]) + | |
287 s.substr(pos); | |
288 } | |
289 for (uint32 i = 0; i < space; i++) { | |
290 int pos = rand_(1, s.length() - 1); | |
291 s = s.substr(0, pos) + " " + s.substr(pos); | |
292 } | |
293 *key = s; | |
294 } | |
295 | |
296 void WebSocketHandshake::Parameter::GenerateKey3() { | |
297 key_3_.clear(); | |
298 for (int i = 0; i < 8; i++) { | |
299 key_3_.append(1, rand_(0, 255)); | |
300 } | |
301 } | |
302 | |
303 } // namespace net | |
OLD | NEW |