Index: net/websockets/websocket_basic_handshake_stream.cc |
diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc |
index da64386a463de3d54143d0eddd7b38ff2dd2c7c4..cdb432edf4f767b042d719a0c85deb6ad7987d84 100644 |
--- a/net/websockets/websocket_basic_handshake_stream.cc |
+++ b/net/websockets/websocket_basic_handshake_stream.cc |
@@ -35,6 +35,7 @@ |
#include "net/socket/client_socket_handle.h" |
#include "net/socket/websocket_transport_client_socket_pool.h" |
#include "net/websockets/websocket_basic_stream.h" |
+#include "net/websockets/websocket_deflate_parameters.h" |
#include "net/websockets/websocket_deflate_predictor.h" |
#include "net/websockets/websocket_deflate_predictor_impl.h" |
#include "net/websockets/websocket_deflate_stream.h" |
@@ -57,14 +58,8 @@ const char kConnectionErrorStatusLine[] = "HTTP/1.1 503 Connection Error"; |
// TODO(ricea): If more extensions are added, replace this with a more general |
// mechanism. |
struct WebSocketExtensionParams { |
- WebSocketExtensionParams() |
- : deflate_enabled(false), |
- client_window_bits(15), |
- deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} |
- |
- bool deflate_enabled; |
- int client_window_bits; |
- WebSocketDeflater::ContextTakeOverMode deflate_mode; |
+ bool deflate_enabled = false; |
+ WebSocketDeflateParameters deflate_parameters; |
}; |
namespace { |
@@ -235,73 +230,6 @@ bool ValidateSubProtocol( |
return true; |
} |
-bool DeflateError(std::string* message, const base::StringPiece& piece) { |
- *message = "Error in permessage-deflate: "; |
- piece.AppendToString(message); |
- return false; |
-} |
- |
-bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, |
- std::string* failure_message, |
- WebSocketExtensionParams* params) { |
- static const char kClientPrefix[] = "client_"; |
- static const char kServerPrefix[] = "server_"; |
- static const char kNoContextTakeover[] = "no_context_takeover"; |
- static const char kMaxWindowBits[] = "max_window_bits"; |
- const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
- static_assert(kPrefixLen == arraysize(kServerPrefix) - 1, |
- "the strings server and client must be the same length"); |
- typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
- |
- DCHECK_EQ("permessage-deflate", extension.name()); |
- const ParameterVector& parameters = extension.parameters(); |
- std::set<std::string> seen_names; |
- for (ParameterVector::const_iterator it = parameters.begin(); |
- it != parameters.end(); ++it) { |
- const std::string& name = it->name(); |
- if (seen_names.count(name) != 0) { |
- return DeflateError( |
- failure_message, |
- "Received duplicate permessage-deflate extension parameter " + name); |
- } |
- seen_names.insert(name); |
- const std::string client_or_server(name, 0, kPrefixLen); |
- const bool is_client = (client_or_server == kClientPrefix); |
- if (!is_client && client_or_server != kServerPrefix) { |
- return DeflateError( |
- failure_message, |
- "Received an unexpected permessage-deflate extension parameter"); |
- } |
- const std::string rest(name, kPrefixLen); |
- if (rest == kNoContextTakeover) { |
- if (it->HasValue()) { |
- return DeflateError(failure_message, |
- "Received invalid " + name + " parameter"); |
- } |
- if (is_client) |
- params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; |
- } else if (rest == kMaxWindowBits) { |
- if (!it->HasValue()) |
- return DeflateError(failure_message, name + " must have value"); |
- int bits = 0; |
- if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || |
- it->value()[0] == '0' || |
- it->value().find_first_not_of("0123456789") != std::string::npos) { |
- return DeflateError(failure_message, |
- "Received invalid " + name + " parameter"); |
- } |
- if (is_client) |
- params->client_window_bits = bits; |
- } else { |
- return DeflateError( |
- failure_message, |
- "Received an unexpected permessage-deflate extension parameter"); |
- } |
- } |
- params->deflate_enabled = true; |
- return true; |
-} |
- |
bool ValidateExtensions(const HttpResponseHeaders* headers, |
std::string* accepted_extensions_descriptor, |
std::string* failure_message, |
@@ -332,11 +260,16 @@ bool ValidateExtensions(const HttpResponseHeaders* headers, |
return false; |
} |
seen_permessage_deflate = true; |
- |
- if (!ValidatePerMessageDeflateExtension(extension, failure_message, |
- params)) { |
+ auto& deflate_parameters = params->deflate_parameters; |
+ if (!deflate_parameters.Initialize(extension, failure_message) || |
+ !deflate_parameters.IsValidAsResponse(failure_message)) { |
+ *failure_message = "Error in permessage-deflate: " + *failure_message; |
return false; |
} |
+ // Note that we don't have to check the request-response compatibility |
+ // here because we send a request compatible with any valid responses. |
+ // TODO(yhirano): Place a DCHECK here. |
+ |
header_values.push_back(header_value); |
} else { |
*failure_message = "Found an unsupported extension '" + |
@@ -347,6 +280,7 @@ bool ValidateExtensions(const HttpResponseHeaders* headers, |
} |
} |
*accepted_extensions_descriptor = base::JoinString(header_values, ", "); |
+ params->deflate_enabled = seen_permessage_deflate; |
return true; |
} |
@@ -531,15 +465,13 @@ scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { |
if (extension_params_->deflate_enabled) { |
UMA_HISTOGRAM_ENUMERATION( |
"Net.WebSocket.DeflateMode", |
- extension_params_->deflate_mode, |
+ extension_params_->deflate_parameters.client_context_take_over_mode(), |
WebSocketDeflater::NUM_CONTEXT_TAKEOVER_MODE_TYPES); |
- return scoped_ptr<WebSocketStream>( |
- new WebSocketDeflateStream(basic_stream.Pass(), |
- extension_params_->deflate_mode, |
- extension_params_->client_window_bits, |
- scoped_ptr<WebSocketDeflatePredictor>( |
- new WebSocketDeflatePredictorImpl))); |
+ return scoped_ptr<WebSocketStream>(new WebSocketDeflateStream( |
+ basic_stream.Pass(), extension_params_->deflate_parameters, |
+ scoped_ptr<WebSocketDeflatePredictor>( |
+ new WebSocketDeflatePredictorImpl))); |
} else { |
return basic_stream.Pass(); |
} |