OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
6 | 6 |
7 #include <algorithm> | 7 #include <algorithm> |
8 #include <iterator> | 8 #include <iterator> |
9 #include <set> | 9 #include <set> |
10 #include <string> | 10 #include <string> |
(...skipping 51 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
62 GET_HEADER_OK, | 62 GET_HEADER_OK, |
63 GET_HEADER_MISSING, | 63 GET_HEADER_MISSING, |
64 GET_HEADER_MULTIPLE, | 64 GET_HEADER_MULTIPLE, |
65 }; | 65 }; |
66 | 66 |
67 std::string MissingHeaderMessage(const std::string& header_name) { | 67 std::string MissingHeaderMessage(const std::string& header_name) { |
68 return std::string("'") + header_name + "' header is missing"; | 68 return std::string("'") + header_name + "' header is missing"; |
69 } | 69 } |
70 | 70 |
71 std::string MultipleHeaderValuesMessage(const std::string& header_name) { | 71 std::string MultipleHeaderValuesMessage(const std::string& header_name) { |
72 return | 72 return std::string("'") + header_name + |
73 std::string("'") + | 73 "' header must not appear more than once in a response"; |
74 header_name + | |
75 "' header must not appear more than once in a response"; | |
76 } | 74 } |
77 | 75 |
78 std::string GenerateHandshakeChallenge() { | 76 std::string GenerateHandshakeChallenge() { |
79 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); | 77 std::string raw_challenge(websockets::kRawChallengeLength, '\0'); |
80 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); | 78 crypto::RandBytes(string_as_array(&raw_challenge), raw_challenge.length()); |
81 std::string encoded_challenge; | 79 std::string encoded_challenge; |
82 base::Base64Encode(raw_challenge, &encoded_challenge); | 80 base::Base64Encode(raw_challenge, &encoded_challenge); |
83 return encoded_challenge; | 81 return encoded_challenge; |
84 } | 82 } |
85 | 83 |
(...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
118 } | 116 } |
119 DCHECK_EQ(result, GET_HEADER_OK); | 117 DCHECK_EQ(result, GET_HEADER_OK); |
120 return true; | 118 return true; |
121 } | 119 } |
122 | 120 |
123 bool ValidateUpgrade(const HttpResponseHeaders* headers, | 121 bool ValidateUpgrade(const HttpResponseHeaders* headers, |
124 std::string* failure_message) { | 122 std::string* failure_message) { |
125 std::string value; | 123 std::string value; |
126 GetHeaderResult result = | 124 GetHeaderResult result = |
127 GetSingleHeaderValue(headers, websockets::kUpgrade, &value); | 125 GetSingleHeaderValue(headers, websockets::kUpgrade, &value); |
128 if (!ValidateHeaderHasSingleValue(result, | 126 if (!ValidateHeaderHasSingleValue( |
129 websockets::kUpgrade, | 127 result, websockets::kUpgrade, failure_message)) { |
130 failure_message)) { | |
131 return false; | 128 return false; |
132 } | 129 } |
133 | 130 |
134 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { | 131 if (!LowerCaseEqualsASCII(value, websockets::kWebSocketLowercase)) { |
135 *failure_message = | 132 *failure_message = "'Upgrade' header value is not 'WebSocket': " + value; |
136 "'Upgrade' header value is not 'WebSocket': " + value; | |
137 return false; | 133 return false; |
138 } | 134 } |
139 return true; | 135 return true; |
140 } | 136 } |
141 | 137 |
142 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, | 138 bool ValidateSecWebSocketAccept(const HttpResponseHeaders* headers, |
143 const std::string& expected, | 139 const std::string& expected, |
144 std::string* failure_message) { | 140 std::string* failure_message) { |
145 std::string actual; | 141 std::string actual; |
146 GetHeaderResult result = | 142 GetHeaderResult result = |
147 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); | 143 GetSingleHeaderValue(headers, websockets::kSecWebSocketAccept, &actual); |
148 if (!ValidateHeaderHasSingleValue(result, | 144 if (!ValidateHeaderHasSingleValue( |
149 websockets::kSecWebSocketAccept, | 145 result, websockets::kSecWebSocketAccept, failure_message)) { |
150 failure_message)) { | |
151 return false; | 146 return false; |
152 } | 147 } |
153 | 148 |
154 if (expected != actual) { | 149 if (expected != actual) { |
155 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; | 150 *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; |
156 return false; | 151 return false; |
157 } | 152 } |
158 return true; | 153 return true; |
159 } | 154 } |
160 | 155 |
(...skipping 35 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
196 has_invalid_protocol = true; | 191 has_invalid_protocol = true; |
197 if (++count > 1) | 192 if (++count > 1) |
198 has_multiple_protocols = true; | 193 has_multiple_protocols = true; |
199 } | 194 } |
200 | 195 |
201 if (has_multiple_protocols) { | 196 if (has_multiple_protocols) { |
202 *failure_message = | 197 *failure_message = |
203 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); | 198 MultipleHeaderValuesMessage(websockets::kSecWebSocketProtocol); |
204 return false; | 199 return false; |
205 } else if (count > 0 && requested_sub_protocols.size() == 0) { | 200 } else if (count > 0 && requested_sub_protocols.size() == 0) { |
206 *failure_message = | 201 *failure_message = std::string( |
207 std::string("Response must not include 'Sec-WebSocket-Protocol' " | 202 "Response must not include 'Sec-WebSocket-Protocol' " |
208 "header if not present in request: ") | 203 "header if not present in request: ") + |
209 + value; | 204 value; |
210 return false; | 205 return false; |
211 } else if (has_invalid_protocol) { | 206 } else if (has_invalid_protocol) { |
212 *failure_message = | 207 *failure_message = "'Sec-WebSocket-Protocol' header value '" + value + |
213 "'Sec-WebSocket-Protocol' header value '" + | 208 "' in response does not match any of sent values"; |
214 value + | |
215 "' in response does not match any of sent values"; | |
216 return false; | 209 return false; |
217 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 210 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
218 *failure_message = | 211 *failure_message = |
219 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 212 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
220 "but no response was received"; | 213 "but no response was received"; |
221 return false; | 214 return false; |
222 } | 215 } |
223 *sub_protocol = value; | 216 *sub_protocol = value; |
224 return true; | 217 return true; |
225 } | 218 } |
(...skipping 13 matching lines...) Expand all Loading... |
239 static const char kMaxWindowBits[] = "max_window_bits"; | 232 static const char kMaxWindowBits[] = "max_window_bits"; |
240 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | 233 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
241 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, | 234 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, |
242 the_strings_server_and_client_must_be_the_same_length); | 235 the_strings_server_and_client_must_be_the_same_length); |
243 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | 236 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
244 | 237 |
245 DCHECK_EQ("permessage-deflate", extension.name()); | 238 DCHECK_EQ("permessage-deflate", extension.name()); |
246 const ParameterVector& parameters = extension.parameters(); | 239 const ParameterVector& parameters = extension.parameters(); |
247 std::set<std::string> seen_names; | 240 std::set<std::string> seen_names; |
248 for (ParameterVector::const_iterator it = parameters.begin(); | 241 for (ParameterVector::const_iterator it = parameters.begin(); |
249 it != parameters.end(); ++it) { | 242 it != parameters.end(); |
| 243 ++it) { |
250 const std::string& name = it->name(); | 244 const std::string& name = it->name(); |
251 if (seen_names.count(name) != 0) { | 245 if (seen_names.count(name) != 0) { |
252 return DeflateError( | 246 return DeflateError( |
253 failure_message, | 247 failure_message, |
254 "Received duplicate permessage-deflate extension parameter " + name); | 248 "Received duplicate permessage-deflate extension parameter " + name); |
255 } | 249 } |
256 seen_names.insert(name); | 250 seen_names.insert(name); |
257 const std::string client_or_server(name, 0, kPrefixLen); | 251 const std::string client_or_server(name, 0, kPrefixLen); |
258 const bool is_client = (client_or_server == kClientPrefix); | 252 const bool is_client = (client_or_server == kClientPrefix); |
259 if (!is_client && client_or_server != kServerPrefix) { | 253 if (!is_client && client_or_server != kServerPrefix) { |
(...skipping 36 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
296 std::string* extensions, | 290 std::string* extensions, |
297 std::string* failure_message, | 291 std::string* failure_message, |
298 WebSocketExtensionParams* params) { | 292 WebSocketExtensionParams* params) { |
299 void* state = NULL; | 293 void* state = NULL; |
300 std::string value; | 294 std::string value; |
301 std::vector<std::string> accepted_extensions; | 295 std::vector<std::string> accepted_extensions; |
302 // TODO(ricea): If adding support for additional extensions, generalise this | 296 // TODO(ricea): If adding support for additional extensions, generalise this |
303 // code. | 297 // code. |
304 bool seen_permessage_deflate = false; | 298 bool seen_permessage_deflate = false; |
305 while (headers->EnumerateHeader( | 299 while (headers->EnumerateHeader( |
306 &state, websockets::kSecWebSocketExtensions, &value)) { | 300 &state, websockets::kSecWebSocketExtensions, &value)) { |
307 WebSocketExtensionParser parser; | 301 WebSocketExtensionParser parser; |
308 parser.Parse(value); | 302 parser.Parse(value); |
309 if (parser.has_error()) { | 303 if (parser.has_error()) { |
310 // TODO(yhirano) Set appropriate failure message. | 304 // TODO(yhirano) Set appropriate failure message. |
311 *failure_message = | 305 *failure_message = |
312 "'Sec-WebSocket-Extensions' header value is " | 306 "'Sec-WebSocket-Extensions' header value is " |
313 "rejected by the parser: " + | 307 "rejected by the parser: " + |
314 value; | 308 value; |
315 return false; | 309 return false; |
316 } | 310 } |
317 if (parser.extension().name() == "permessage-deflate") { | 311 if (parser.extension().name() == "permessage-deflate") { |
318 if (seen_permessage_deflate) { | 312 if (seen_permessage_deflate) { |
319 *failure_message = "Received duplicate permessage-deflate response"; | 313 *failure_message = "Received duplicate permessage-deflate response"; |
320 return false; | 314 return false; |
321 } | 315 } |
322 seen_permessage_deflate = true; | 316 seen_permessage_deflate = true; |
323 if (!ValidatePerMessageDeflateExtension( | 317 if (!ValidatePerMessageDeflateExtension( |
324 parser.extension(), failure_message, params)) | 318 parser.extension(), failure_message, params)) |
325 return false; | 319 return false; |
326 } else { | 320 } else { |
327 *failure_message = | 321 *failure_message = "Found an unsupported extension '" + |
328 "Found an unsupported extension '" + | 322 parser.extension().name() + |
329 parser.extension().name() + | 323 "' in 'Sec-WebSocket-Extensions' header"; |
330 "' in 'Sec-WebSocket-Extensions' header"; | |
331 return false; | 324 return false; |
332 } | 325 } |
333 accepted_extensions.push_back(value); | 326 accepted_extensions.push_back(value); |
334 } | 327 } |
335 *extensions = JoinString(accepted_extensions, ", "); | 328 *extensions = JoinString(accepted_extensions, ", "); |
336 return true; | 329 return true; |
337 } | 330 } |
338 | 331 |
339 } // namespace | 332 } // namespace |
340 | 333 |
341 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( | 334 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( |
342 scoped_ptr<ClientSocketHandle> connection, | 335 scoped_ptr<ClientSocketHandle> connection, |
343 WebSocketStream::ConnectDelegate* connect_delegate, | 336 WebSocketStream::ConnectDelegate* connect_delegate, |
344 bool using_proxy, | 337 bool using_proxy, |
345 std::vector<std::string> requested_sub_protocols, | 338 std::vector<std::string> requested_sub_protocols, |
346 std::vector<std::string> requested_extensions) | 339 std::vector<std::string> requested_extensions) |
347 : state_(connection.release(), using_proxy), | 340 : state_(connection.release(), using_proxy), |
348 connect_delegate_(connect_delegate), | 341 connect_delegate_(connect_delegate), |
349 http_response_info_(NULL), | 342 http_response_info_(NULL), |
350 requested_sub_protocols_(requested_sub_protocols), | 343 requested_sub_protocols_(requested_sub_protocols), |
351 requested_extensions_(requested_extensions) {} | 344 requested_extensions_(requested_extensions) { |
| 345 } |
352 | 346 |
353 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() {} | 347 WebSocketBasicHandshakeStream::~WebSocketBasicHandshakeStream() { |
| 348 } |
354 | 349 |
355 int WebSocketBasicHandshakeStream::InitializeStream( | 350 int WebSocketBasicHandshakeStream::InitializeStream( |
356 const HttpRequestInfo* request_info, | 351 const HttpRequestInfo* request_info, |
357 RequestPriority priority, | 352 RequestPriority priority, |
358 const BoundNetLog& net_log, | 353 const BoundNetLog& net_log, |
359 const CompletionCallback& callback) { | 354 const CompletionCallback& callback) { |
360 url_ = request_info->url; | 355 url_ = request_info->url; |
361 state_.Initialize(request_info, priority, net_log, callback); | 356 state_.Initialize(request_info, priority, net_log, callback); |
362 return OK; | 357 return OK; |
363 } | 358 } |
(...skipping 111 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
475 | 470 |
476 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { | 471 void WebSocketBasicHandshakeStream::GetSSLInfo(SSLInfo* ssl_info) { |
477 parser()->GetSSLInfo(ssl_info); | 472 parser()->GetSSLInfo(ssl_info); |
478 } | 473 } |
479 | 474 |
480 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( | 475 void WebSocketBasicHandshakeStream::GetSSLCertRequestInfo( |
481 SSLCertRequestInfo* cert_request_info) { | 476 SSLCertRequestInfo* cert_request_info) { |
482 parser()->GetSSLCertRequestInfo(cert_request_info); | 477 parser()->GetSSLCertRequestInfo(cert_request_info); |
483 } | 478 } |
484 | 479 |
485 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { return false; } | 480 bool WebSocketBasicHandshakeStream::IsSpdyHttpStream() const { |
| 481 return false; |
| 482 } |
486 | 483 |
487 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { | 484 void WebSocketBasicHandshakeStream::Drain(HttpNetworkSession* session) { |
488 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); | 485 HttpResponseBodyDrainer* drainer = new HttpResponseBodyDrainer(this); |
489 drainer->Start(session); | 486 drainer->Start(session); |
490 // |drainer| will delete itself. | 487 // |drainer| will delete itself. |
491 } | 488 } |
492 | 489 |
493 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { | 490 void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { |
494 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is | 491 // TODO(ricea): See TODO comment in HttpBasicStream::SetPriority(). If it is |
495 // gone, then copy whatever has happened there over here. | 492 // gone, then copy whatever has happened there over here. |
(...skipping 100 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
596 std::string("Error during WebSocket handshake: ") + ErrorToString(rv); | 593 std::string("Error during WebSocket handshake: ") + ErrorToString(rv); |
597 OnFinishOpeningHandshake(); | 594 OnFinishOpeningHandshake(); |
598 return rv; | 595 return rv; |
599 } | 596 } |
600 } | 597 } |
601 | 598 |
602 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( | 599 int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
603 const HttpResponseHeaders* headers) { | 600 const HttpResponseHeaders* headers) { |
604 extension_params_.reset(new WebSocketExtensionParams); | 601 extension_params_.reset(new WebSocketExtensionParams); |
605 if (ValidateUpgrade(headers, &failure_message_) && | 602 if (ValidateUpgrade(headers, &failure_message_) && |
606 ValidateSecWebSocketAccept(headers, | 603 ValidateSecWebSocketAccept( |
607 handshake_challenge_response_, | 604 headers, handshake_challenge_response_, &failure_message_) && |
608 &failure_message_) && | |
609 ValidateConnection(headers, &failure_message_) && | 605 ValidateConnection(headers, &failure_message_) && |
610 ValidateSubProtocol(headers, | 606 ValidateSubProtocol(headers, |
611 requested_sub_protocols_, | 607 requested_sub_protocols_, |
612 &sub_protocol_, | 608 &sub_protocol_, |
613 &failure_message_) && | 609 &failure_message_) && |
614 ValidateExtensions(headers, | 610 ValidateExtensions(headers, |
615 requested_extensions_, | 611 requested_extensions_, |
616 &extensions_, | 612 &extensions_, |
617 &failure_message_, | 613 &failure_message_, |
618 extension_params_.get())) { | 614 extension_params_.get())) { |
619 return OK; | 615 return OK; |
620 } | 616 } |
621 failure_message_ = "Error during WebSocket handshake: " + failure_message_; | 617 failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
622 return ERR_INVALID_RESPONSE; | 618 return ERR_INVALID_RESPONSE; |
623 } | 619 } |
624 | 620 |
625 } // namespace net | 621 } // namespace net |
OLD | NEW |