OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/websockets/websocket_basic_handshake_stream.h" | 5 #include "net/websockets/websocket_basic_handshake_stream.h" |
6 | 6 |
7 #include <algorithm> | 7 #include <algorithm> |
8 #include <iterator> | 8 #include <iterator> |
9 #include <set> | 9 #include <set> |
10 #include <string> | 10 #include <string> |
11 #include <vector> | 11 #include <vector> |
12 | 12 |
13 #include "base/base64.h" | 13 #include "base/base64.h" |
14 #include "base/basictypes.h" | 14 #include "base/basictypes.h" |
15 #include "base/bind.h" | 15 #include "base/bind.h" |
16 #include "base/containers/hash_tables.h" | 16 #include "base/containers/hash_tables.h" |
17 #include "base/stl_util.h" | 17 #include "base/stl_util.h" |
18 #include "base/strings/string_number_conversions.h" | 18 #include "base/strings/string_number_conversions.h" |
| 19 #include "base/strings/string_piece.h" |
19 #include "base/strings/string_util.h" | 20 #include "base/strings/string_util.h" |
20 #include "base/strings/stringprintf.h" | 21 #include "base/strings/stringprintf.h" |
21 #include "base/time/time.h" | 22 #include "base/time/time.h" |
22 #include "crypto/random.h" | 23 #include "crypto/random.h" |
23 #include "net/http/http_request_headers.h" | 24 #include "net/http/http_request_headers.h" |
24 #include "net/http/http_request_info.h" | 25 #include "net/http/http_request_info.h" |
25 #include "net/http/http_response_body_drainer.h" | 26 #include "net/http/http_response_body_drainer.h" |
26 #include "net/http/http_response_headers.h" | 27 #include "net/http/http_response_headers.h" |
27 #include "net/http/http_status_code.h" | 28 #include "net/http/http_status_code.h" |
28 #include "net/http/http_stream_parser.h" | 29 #include "net/http/http_stream_parser.h" |
(...skipping 186 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
215 } else if (requested_sub_protocols.size() > 0 && count == 0) { | 216 } else if (requested_sub_protocols.size() > 0 && count == 0) { |
216 *failure_message = | 217 *failure_message = |
217 "Sent non-empty 'Sec-WebSocket-Protocol' header " | 218 "Sent non-empty 'Sec-WebSocket-Protocol' header " |
218 "but no response was received"; | 219 "but no response was received"; |
219 return false; | 220 return false; |
220 } | 221 } |
221 *sub_protocol = value; | 222 *sub_protocol = value; |
222 return true; | 223 return true; |
223 } | 224 } |
224 | 225 |
| 226 bool DeflateError(std::string* message, const base::StringPiece& piece) { |
| 227 *message = "Error in permessage-deflate: "; |
| 228 AppendToString(piece, message); |
| 229 return false; |
| 230 } |
| 231 |
225 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, | 232 bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, |
226 std::string* failure_message, | 233 std::string* failure_message, |
227 WebSocketExtensionParams* params) { | 234 WebSocketExtensionParams* params) { |
228 static const char kClientPrefix[] = "client_"; | 235 static const char kClientPrefix[] = "client_"; |
229 static const char kServerPrefix[] = "server_"; | 236 static const char kServerPrefix[] = "server_"; |
230 static const char kNoContextTakeover[] = "no_context_takeover"; | 237 static const char kNoContextTakeover[] = "no_context_takeover"; |
231 static const char kMaxWindowBits[] = "max_window_bits"; | 238 static const char kMaxWindowBits[] = "max_window_bits"; |
232 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; | 239 const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
233 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, | 240 COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, |
234 the_strings_server_and_client_must_be_the_same_length); | 241 the_strings_server_and_client_must_be_the_same_length); |
235 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; | 242 typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
236 | 243 |
237 DCHECK_EQ("permessage-deflate", extension.name()); | 244 DCHECK_EQ("permessage-deflate", extension.name()); |
238 const ParameterVector& parameters = extension.parameters(); | 245 const ParameterVector& parameters = extension.parameters(); |
239 std::set<std::string> seen_names; | 246 std::set<std::string> seen_names; |
240 for (ParameterVector::const_iterator it = parameters.begin(); | 247 for (ParameterVector::const_iterator it = parameters.begin(); |
241 it != parameters.end(); ++it) { | 248 it != parameters.end(); ++it) { |
242 const std::string& name = it->name(); | 249 const std::string& name = it->name(); |
243 if (seen_names.count(name) != 0) { | 250 if (seen_names.count(name) != 0) { |
244 *failure_message = | 251 return DeflateError( |
245 "Received duplicate permessage-deflate extension parameter " + name; | 252 failure_message, |
246 return false; | 253 "Received duplicate permessage-deflate extension parameter " + name); |
247 } | 254 } |
248 seen_names.insert(name); | 255 seen_names.insert(name); |
249 const std::string client_or_server(name, 0, kPrefixLen); | 256 const std::string client_or_server(name, 0, kPrefixLen); |
250 const bool is_client = (client_or_server == kClientPrefix); | 257 const bool is_client = (client_or_server == kClientPrefix); |
251 if (!is_client && client_or_server != kServerPrefix) { | 258 if (!is_client && client_or_server != kServerPrefix) { |
252 *failure_message = | 259 return DeflateError( |
253 "Received an unexpected permessage-deflate extension parameter"; | 260 failure_message, |
254 return false; | 261 "Received an unexpected permessage-deflate extension parameter"); |
255 } | 262 } |
256 const std::string rest(name, kPrefixLen); | 263 const std::string rest(name, kPrefixLen); |
257 if (rest == kNoContextTakeover) { | 264 if (rest == kNoContextTakeover) { |
258 if (it->HasValue()) { | 265 if (it->HasValue()) { |
259 *failure_message = "Received invalid " + name + " parameter"; | 266 return DeflateError(failure_message, |
260 return false; | 267 "Received invalid " + name + " parameter"); |
261 } | 268 } |
262 if (is_client) | 269 if (is_client) |
263 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; | 270 params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; |
264 } else if (rest == kMaxWindowBits) { | 271 } else if (rest == kMaxWindowBits) { |
265 if (!it->HasValue()) { | 272 if (!it->HasValue()) |
266 *failure_message = name + " must have value"; | 273 return DeflateError(failure_message, name + " must have value"); |
267 return false; | |
268 } | |
269 int bits = 0; | 274 int bits = 0; |
270 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || | 275 if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || |
271 it->value()[0] == '0' || | 276 it->value()[0] == '0' || |
272 it->value().find_first_not_of("0123456789") != std::string::npos) { | 277 it->value().find_first_not_of("0123456789") != std::string::npos) { |
273 *failure_message = "Received invalid " + name + " parameter"; | 278 return DeflateError(failure_message, |
274 return false; | 279 "Received invalid " + name + " parameter"); |
275 } | 280 } |
276 if (is_client) | 281 if (is_client) |
277 params->client_window_bits = bits; | 282 params->client_window_bits = bits; |
278 } else { | 283 } else { |
279 *failure_message = | 284 return DeflateError( |
280 "Received an unexpected permessage-deflate extension parameter"; | 285 failure_message, |
281 return false; | 286 "Received an unexpected permessage-deflate extension parameter"); |
282 } | 287 } |
283 } | 288 } |
284 params->deflate_enabled = true; | 289 params->deflate_enabled = true; |
285 return true; | 290 return true; |
286 } | 291 } |
287 | 292 |
288 bool ValidateExtensions(const HttpResponseHeaders* headers, | 293 bool ValidateExtensions(const HttpResponseHeaders* headers, |
289 const std::vector<std::string>& requested_extensions, | 294 const std::vector<std::string>& requested_extensions, |
290 std::string* extensions, | 295 std::string* extensions, |
291 std::string* failure_message, | 296 std::string* failure_message, |
(...skipping 304 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
596 &extensions_, | 601 &extensions_, |
597 &failure_message_, | 602 &failure_message_, |
598 extension_params_.get())) { | 603 extension_params_.get())) { |
599 return OK; | 604 return OK; |
600 } | 605 } |
601 failure_message_ = "Error during WebSocket handshake: " + failure_message_; | 606 failure_message_ = "Error during WebSocket handshake: " + failure_message_; |
602 return ERR_INVALID_RESPONSE; | 607 return ERR_INVALID_RESPONSE; |
603 } | 608 } |
604 | 609 |
605 } // namespace net | 610 } // namespace net |
OLD | NEW |