OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
6 | 6 |
7 #include <algorithm> | 7 #include <algorithm> |
8 #include <iterator> | 8 #include <iterator> |
9 #include <set> | 9 #include <set> |
10 #include <string> | 10 #include <string> |
(...skipping 17 matching lines...) Expand all Loading... |
28 #include "net/base/io_buffer.h" | 28 #include "net/base/io_buffer.h" |
29 #include "net/http/http_request_headers.h" | 29 #include "net/http/http_request_headers.h" |
30 #include "net/http/http_request_info.h" | 30 #include "net/http/http_request_info.h" |
31 #include "net/http/http_response_body_drainer.h" | 31 #include "net/http/http_response_body_drainer.h" |
32 #include "net/http/http_response_headers.h" | 32 #include "net/http/http_response_headers.h" |
33 #include "net/http/http_status_code.h" | 33 #include "net/http/http_status_code.h" |
34 #include "net/http/http_stream_parser.h" | 34 #include "net/http/http_stream_parser.h" |
35 #include "net/socket/client_socket_handle.h" | 35 #include "net/socket/client_socket_handle.h" |
36 #include "net/socket/websocket_transport_client_socket_pool.h" | 36 #include "net/socket/websocket_transport_client_socket_pool.h" |
37 #include "net/websockets/websocket_basic_stream.h" | 37 #include "net/websockets/websocket_basic_stream.h" |
| 38 #include "net/websockets/websocket_deflate_parameters.h" |
38 #include "net/websockets/websocket_deflate_predictor.h" | 39 #include "net/websockets/websocket_deflate_predictor.h" |
39 #include "net/websockets/websocket_deflate_predictor_impl.h" | 40 #include "net/websockets/websocket_deflate_predictor_impl.h" |
40 #include "net/websockets/websocket_deflate_stream.h" | 41 #include "net/websockets/websocket_deflate_stream.h" |
41 #include "net/websockets/websocket_deflater.h" | 42 #include "net/websockets/websocket_deflater.h" |
42 #include "net/websockets/websocket_extension_parser.h" | 43 #include "net/websockets/websocket_extension_parser.h" |
43 #include "net/websockets/websocket_handshake_challenge.h" | 44 #include "net/websockets/websocket_handshake_challenge.h" |
44 #include "net/websockets/websocket_handshake_constants.h" | 45 #include "net/websockets/websocket_handshake_constants.h" |
45 #include "net/websockets/websocket_handshake_request_info.h" | 46 #include "net/websockets/websocket_handshake_request_info.h" |
46 #include "net/websockets/websocket_handshake_response_info.h" | 47 #include "net/websockets/websocket_handshake_response_info.h" |
47 #include "net/websockets/websocket_stream.h" | 48 #include "net/websockets/websocket_stream.h" |
48 | 49 |
49 namespace net { | 50 namespace net { |
50 | 51 |
51 namespace { | 52 namespace { |
52 | 53 |
53 const char kConnectionErrorStatusLine[] = "HTTP/1.1 503 Connection Error"; | 54 const char kConnectionErrorStatusLine[] = "HTTP/1.1 503 Connection Error"; |
54 | 55 |
55 } // namespace | 56 } // namespace |
56 | 57 |
57 // TODO(ricea): If more extensions are added, replace this with a more general | 58 // TODO(ricea): If more extensions are added, replace this with a more general |
58 // mechanism. | 59 // mechanism. |
59 struct WebSocketExtensionParams { | 60 struct WebSocketExtensionParams { |
60 WebSocketExtensionParams() | 61 bool deflate_enabled = false; |
61 : deflate_enabled(false), | 62 WebSocketDeflateParameters deflate_parameters; |
62 client_window_bits(15), | |
63 deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} | |
64 | |
65 bool deflate_enabled; | |
66 int client_window_bits; | |
67 WebSocketDeflater::ContextTakeOverMode deflate_mode; | |
68 }; | 63 }; |
69 | 64 |
70 namespace { | 65 namespace { |
71 | 66 |
72 enum GetHeaderResult { | 67 enum GetHeaderResult { |
73 GET_HEADER_OK, | 68 GET_HEADER_OK, |
74 GET_HEADER_MISSING, | 69 GET_HEADER_MISSING, |
75 GET_HEADER_MULTIPLE, | 70 GET_HEADER_MULTIPLE, |
76 }; | 71 }; |
77 | 72 |
(...skipping 150 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
228 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 223 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
229 *failure_message = | 224 *failure_message = |
230 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 225 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
231 "but no response was received"; | 226 "but no response was received"; |
232 return false; | 227 return false; |
233 } | 228 } |
234 *sub_protocol = value; | 229 *sub_protocol = value; |
235 return true; | 230 return true; |
236 } | 231 } |
237 | 232 |
238 bool DeflateError(std::string* message, const base::StringPiece& piece) { | |
239 *message = "Error in permessage-deflate: "; | |
240 piece.AppendToString(message); | |
241 return false; | |
242 } | |
243 | |
244 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, | |
245 std::string* failure_message, | |
246 WebSocketExtensionParams* params) { | |
247 static const char kClientPrefix[] = "client_"; | |
248 static const char kServerPrefix[] = "server_"; | |
249 static const char kNoContextTakeover[] = "no_context_takeover"; | |
250 static const char kMaxWindowBits[] = "max_window_bits"; | |
251 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | |
252 static_assert(kPrefixLen == arraysize(kServerPrefix) - 1, | |
253 "the strings server and client must be the same length"); | |
254 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | |
255 | |
256 DCHECK_EQ("permessage-deflate", extension.name()); | |
257 const ParameterVector& parameters = extension.parameters(); | |
258 std::set<std::string> seen_names; | |
259 for (ParameterVector::const_iterator it = parameters.begin(); | |
260 it != parameters.end(); ++it) { | |
261 const std::string& name = it->name(); | |
262 if (seen_names.count(name) != 0) { | |
263 return DeflateError( | |
264 failure_message, | |
265 "Received duplicate permessage-deflate extension parameter " + name); | |
266 } | |
267 seen_names.insert(name); | |
268 const std::string client_or_server(name, 0, kPrefixLen); | |
269 const bool is_client = (client_or_server == kClientPrefix); | |
270 if (!is_client && client_or_server != kServerPrefix) { | |
271 return DeflateError( | |
272 failure_message, | |
273 "Received an unexpected permessage-deflate extension parameter"); | |
274 } | |
275 const std::string rest(name, kPrefixLen); | |
276 if (rest == kNoContextTakeover) { | |
277 if (it->HasValue()) { | |
278 return DeflateError(failure_message, | |
279 "Received invalid " + name + " parameter"); | |
280 } | |
281 if (is_client) | |
282 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; | |
283 } else if (rest == kMaxWindowBits) { | |
284 if (!it->HasValue()) | |
285 return DeflateError(failure_message, name + " must have value"); | |
286 int bits = 0; | |
287 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || | |
288 it->value()[0] == '0' || | |
289 it->value().find_first_not_of("0123456789") != std::string::npos) { | |
290 return DeflateError(failure_message, | |
291 "Received invalid " + name + " parameter"); | |
292 } | |
293 if (is_client) | |
294 params->client_window_bits = bits; | |
295 } else { | |
296 return DeflateError( | |
297 failure_message, | |
298 "Received an unexpected permessage-deflate extension parameter"); | |
299 } | |
300 } | |
301 params->deflate_enabled = true; | |
302 return true; | |
303 } | |
304 | |
305 bool ValidateExtensions(const HttpResponseHeaders* headers, | 233 bool ValidateExtensions(const HttpResponseHeaders* headers, |
306 std::string* accepted_extensions_descriptor, | 234 std::string* accepted_extensions_descriptor, |
307 std::string* failure_message, | 235 std::string* failure_message, |
308 WebSocketExtensionParams* params) { | 236 WebSocketExtensionParams* params) { |
309 void* state = nullptr; | 237 void* state = nullptr; |
310 std::string header_value; | 238 std::string header_value; |
311 std::vector<std::string> header_values; | 239 std::vector<std::string> header_values; |
312 // TODO(ricea): If adding support for additional extensions, generalise this | 240 // TODO(ricea): If adding support for additional extensions, generalise this |
313 // code. | 241 // code. |
314 bool seen_permessage_deflate = false; | 242 bool seen_permessage_deflate = false; |
(...skipping 10 matching lines...) Expand all Loading... |
325 } | 253 } |
326 | 254 |
327 const std::vector<WebSocketExtension>& extensions = parser.extensions(); | 255 const std::vector<WebSocketExtension>& extensions = parser.extensions(); |
328 for (const auto& extension : extensions) { | 256 for (const auto& extension : extensions) { |
329 if (extension.name() == "permessage-deflate") { | 257 if (extension.name() == "permessage-deflate") { |
330 if (seen_permessage_deflate) { | 258 if (seen_permessage_deflate) { |
331 *failure_message = "Received duplicate permessage-deflate response"; | 259 *failure_message = "Received duplicate permessage-deflate response"; |
332 return false; | 260 return false; |
333 } | 261 } |
334 seen_permessage_deflate = true; | 262 seen_permessage_deflate = true; |
335 | 263 auto& deflate_parameters = params->deflate_parameters; |
336 if (!ValidatePerMessageDeflateExtension(extension, failure_message, | 264 if (!deflate_parameters.Initialize(extension, failure_message) || |
337 params)) { | 265 !deflate_parameters.IsValidAsResponse(failure_message)) { |
| 266 *failure_message = "Error in permessage-deflate: " + *failure_message; |
338 return false; | 267 return false; |
339 } | 268 } |
| 269 // Note that we don't have to check the request-response compatibility |
| 270 // here because we send a request compatible with any valid responses. |
| 271 // TODO(yhirano): Place a DCHECK here. |
| 272 |
340 header_values.push_back(header_value); | 273 header_values.push_back(header_value); |
341 } else { | 274 } else { |
342 *failure_message = "Found an unsupported extension '" + | 275 *failure_message = "Found an unsupported extension '" + |
343 extension.name() + | 276 extension.name() + |
344 "' in 'Sec-WebSocket-Extensions' header"; | 277 "' in 'Sec-WebSocket-Extensions' header"; |
345 return false; | 278 return false; |
346 } | 279 } |
347 } | 280 } |
348 } | 281 } |
349 *accepted_extensions_descriptor = base::JoinString(header_values, ", "); | 282 *accepted_extensions_descriptor = base::JoinString(header_values, ", "); |
| 283 params->deflate_enabled = seen_permessage_deflate; |
350 return true; | 284 return true; |
351 } | 285 } |
352 | 286 |
353 } // namespace | 287 } // namespace |
354 | 288 |
355 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( | 289 WebSocketBasicHandshakeStream::WebSocketBasicHandshakeStream( |
356 scoped_ptr<ClientSocketHandle> connection, | 290 scoped_ptr<ClientSocketHandle> connection, |
357 WebSocketStream::ConnectDelegate* connect_delegate, | 291 WebSocketStream::ConnectDelegate* connect_delegate, |
358 bool using_proxy, | 292 bool using_proxy, |
359 std::vector<std::string> requested_sub_protocols, | 293 std::vector<std::string> requested_sub_protocols, |
(...skipping 164 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
524 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection()); | 458 WebSocketTransportClientSocketPool::UnlockEndpoint(state_.connection()); |
525 scoped_ptr<WebSocketStream> basic_stream( | 459 scoped_ptr<WebSocketStream> basic_stream( |
526 new WebSocketBasicStream(state_.ReleaseConnection(), | 460 new WebSocketBasicStream(state_.ReleaseConnection(), |
527 state_.read_buf(), | 461 state_.read_buf(), |
528 sub_protocol_, | 462 sub_protocol_, |
529 extensions_)); | 463 extensions_)); |
530 DCHECK(extension_params_.get()); | 464 DCHECK(extension_params_.get()); |
531 if (extension_params_->deflate_enabled) { | 465 if (extension_params_->deflate_enabled) { |
532 UMA_HISTOGRAM_ENUMERATION( | 466 UMA_HISTOGRAM_ENUMERATION( |
533 "Net.WebSocket.DeflateMode", | 467 "Net.WebSocket.DeflateMode", |
534 extension_params_->deflate_mode, | 468 extension_params_->deflate_parameters.client_context_take_over_mode(), |
535 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); | 469 WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); |
536 | 470 |
537 return scoped_ptr<WebSocketStream>( | 471 return scoped_ptr<WebSocketStream>(new WebSocketDeflateStream( |
538 new WebSocketDeflateStream(basic_stream.Pass(), | 472 basic_stream.Pass(), extension_params_->deflate_parameters, |
539 extension_params_->deflate_mode, | 473 scoped_ptr<WebSocketDeflatePredictor>( |
540 extension_params_->client_window_bits, | 474 new WebSocketDeflatePredictorImpl))); |
541 scoped_ptr<WebSocketDeflatePredictor>( | |
542 new WebSocketDeflatePredictorImpl))); | |
543 } else { | 475 } else { |
544 return basic_stream.Pass(); | 476 return basic_stream.Pass(); |
545 } | 477 } |
546 } | 478 } |
547 | 479 |
548 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( | 480 void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
549 const std::string& key) { | 481 const std::string& key) { |
550 handshake_challenge_for_testing_.reset(new std::string(key)); | 482 handshake_challenge_for_testing_.reset(new std::string(key)); |
551 } | 483 } |
552 | 484 |
(...skipping 92 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
645 set_failure_message("Error during WebSocket handshake: " + failure_message); | 577 set_failure_message("Error during WebSocket handshake: " + failure_message); |
646 return ERR_INVALID_RESPONSE; | 578 return ERR_INVALID_RESPONSE; |
647 } | 579 } |
648 | 580 |
649 void WebSocketBasicHandshakeStream::set_failure_message( | 581 void WebSocketBasicHandshakeStream::set_failure_message( |
650 const std::string& failure_message) { | 582 const std::string& failure_message) { |
651 *failure_message_ = failure_message; | 583 *failure_message_ = failure_message; |
652 } | 584 } |
653 | 585 |
654 } // namespace net | 586 } // namespace net |
OLD | NEW |