Chromium Code Reviews| Index: net/websockets/websocket_basic_handshake_stream.cc |
| diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc |
| index 3d2bcdfb742421ecfbfd71f592bf59a544dfe664..9f0b0dcddc347ada04abe6411b7ecee2e5a23497 100644 |
| --- a/net/websockets/websocket_basic_handshake_stream.cc |
| +++ b/net/websockets/websocket_basic_handshake_stream.cc |
| @@ -6,6 +6,7 @@ |
| #include <algorithm> |
| #include <iterator> |
| +#include <set> |
| #include <string> |
| #include <vector> |
| @@ -14,6 +15,7 @@ |
| #include "base/bind.h" |
| #include "base/containers/hash_tables.h" |
| #include "base/stl_util.h" |
| +#include "base/strings/string_number_conversions.h" |
| #include "base/strings/string_util.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/time/time.h" |
| @@ -26,6 +28,10 @@ |
| #include "net/http/http_stream_parser.h" |
| #include "net/socket/client_socket_handle.h" |
| #include "net/websockets/websocket_basic_stream.h" |
| +#include "net/websockets/websocket_deflate_predictor.h" |
| +#include "net/websockets/websocket_deflate_predictor_impl.h" |
| +#include "net/websockets/websocket_deflate_stream.h" |
| +#include "net/websockets/websocket_deflater.h" |
| #include "net/websockets/websocket_extension_parser.h" |
| #include "net/websockets/websocket_handshake_constants.h" |
| #include "net/websockets/websocket_handshake_handler.h" |
| @@ -34,6 +40,20 @@ |
| #include "net/websockets/websocket_stream.h" |
| namespace net { |
| + |
| +// TODO(ricea): If more extensions are added, replace this with a more general |
| +// mechanism. |
| +struct WebSocketExtensionParams { |
| + WebSocketExtensionParams() |
| + : deflate_enabled(false), |
| + client_window_bits(15), |
| + deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {} |
| + |
| + bool deflate_enabled; |
| + int client_window_bits; |
| + WebSocketDeflater::ContextTakeOverMode deflate_mode; |
| +}; |
| + |
| namespace { |
| enum GetHeaderResult { |
| @@ -202,12 +222,80 @@ bool ValidateSubProtocol( |
| return true; |
| } |
| +bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension, |
| + std::string* failure_message, |
| + WebSocketExtensionParams* params) { |
| + static const char kClientPrefix[] = "client_"; |
| + static const char kServerPrefix[] = "server_"; |
| + static const char kNoContextTakeover[] = "no_context_takeover"; |
| + static const char kMaxWindowBits[] = "max_window_bits"; |
| + const size_t kPrefixLen = arraysize(kClientPrefix) - 1; |
| + COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1, |
| + the_strings_server_and_client_must_be_the_same_length); |
| + typedef std::vector<WebSocketExtension::Parameter> ParameterVector; |
| + |
| + DCHECK(extension.name() == "permessage-deflate"); |
| + const ParameterVector& parameters = extension.parameters(); |
| + std::set<std::string> seen_names; |
| + for (ParameterVector::const_iterator it = parameters.begin(); |
| + it != parameters.end(); ++it) { |
| + const std::string& name = it->name(); |
| + if (seen_names.count(name) != 0) { |
| + *failure_message = |
| + "Received duplicate permessage-deflate extension parameter " + name; |
| + return false; |
| + } |
| + seen_names.insert(name); |
| + const std::string client_or_server(name, 0, kPrefixLen); |
| + const bool is_client = (client_or_server == kClientPrefix); |
| + if (!is_client && client_or_server != kServerPrefix) { |
| + *failure_message = |
| + "Received an unexpected permessage-deflate extension parameter"; |
| + return false; |
| + } |
| + const std::string rest(name, kPrefixLen); |
| + if (rest == kNoContextTakeover) { |
| + if (it->HasValue()) { |
| + *failure_message = "Received invalid " + name + " parameter"; |
| + return false; |
| + } |
| + if (is_client) |
| + params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT; |
| + } else if (rest == kMaxWindowBits) { |
| + if (!it->HasValue()) { |
| + *failure_message = name + " must have value"; |
| + return false; |
| + } |
| + int bits = 0; |
| + if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 || |
| + it->value()[0] == '0' || |
| + it->value().find_first_not_of("0123456789") != std::string::npos) { |
| + *failure_message = "Received invalid " + name + " parameter"; |
| + return false; |
| + } |
| + if (is_client) |
| + params->client_window_bits = bits; |
| + } else { |
| + *failure_message = |
| + "Received an unexpected permessage-deflate extension parameter"; |
| + return false; |
| + } |
| + } |
| + params->deflate_enabled = true; |
| + return true; |
| +} |
| + |
| bool ValidateExtensions(const HttpResponseHeaders* headers, |
| const std::vector<std::string>& requested_extensions, |
| std::string* extensions, |
| - std::string* failure_message) { |
| + std::string* failure_message, |
| + WebSocketExtensionParams* params) { |
| void* state = NULL; |
| std::string value; |
| + std::vector<std::string> accepted_extensions; |
| + // TODO(ricea): If adding support for additional extensions, generalise this |
| + // code. |
| + bool seen_permessage_deflate = false; |
| while (headers->EnumerateHeader( |
| &state, websockets::kSecWebSocketExtensions, &value)) { |
| WebSocketExtensionParser parser; |
| @@ -220,13 +308,25 @@ bool ValidateExtensions(const HttpResponseHeaders* headers, |
| value; |
| return false; |
| } |
| - // TODO(ricea): Accept permessage-deflate with valid parameters. |
| - *failure_message = |
| - "Found an unsupported extension '" + |
| - parser.extension().name() + |
| - "' in 'Sec-WebSocket-Extensions' header"; |
| - return false; |
| + if (parser.extension().name() == "permessage-deflate") { |
| + if (seen_permessage_deflate) { |
| + *failure_message = "Received duplicate permessage-deflate response"; |
| + return false; |
| + } |
| + seen_permessage_deflate = true; |
| + if (!ValidatePerMessageDeflateExtension( |
| + parser.extension(), failure_message, params)) |
| + return false; |
| + } else { |
| + *failure_message = |
| + "Found an unsupported extension '" + |
| + parser.extension().name() + |
| + "' in 'Sec-WebSocket-Extensions' header"; |
| + return false; |
| + } |
| + accepted_extensions.push_back(value); |
| } |
| + *extensions = JoinString(accepted_extensions, ", "); |
| return true; |
| } |
| @@ -284,12 +384,12 @@ int WebSocketBasicHandshakeStream::SendRequest( |
| } |
| enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge); |
| - AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, |
| - requested_sub_protocols_, |
| - &enriched_headers); |
| AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions, |
| requested_extensions_, |
| &enriched_headers); |
| + AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol, |
| + requested_sub_protocols_, |
| + &enriched_headers); |
| ComputeSecWebSocketAccept(handshake_challenge, |
| &handshake_challenge_response_); |
| @@ -393,16 +493,25 @@ void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) { |
| } |
| scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { |
| - // TODO(ricea): Add deflate support. |
| - |
| // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make |
|
tyoshino (SeeGerritForStatus)
2014/01/27 16:23:24
DCHECK(extension_params_.get())
Adam Rice
2014/01/28 08:07:52
Done.
|
| // sure it does not touch it again before it is destroyed. |
| state_.DeleteParser(); |
| - return scoped_ptr<WebSocketStream>( |
| + scoped_ptr<WebSocketStream> basic_stream( |
| new WebSocketBasicStream(state_.ReleaseConnection(), |
| state_.read_buf(), |
| sub_protocol_, |
| extensions_)); |
| + |
| + if (extension_params_->deflate_enabled) { |
| + return scoped_ptr<WebSocketStream>( |
| + new WebSocketDeflateStream(basic_stream.Pass(), |
| + extension_params_->deflate_mode, |
| + extension_params_->client_window_bits, |
| + scoped_ptr<WebSocketDeflatePredictor>( |
| + new WebSocketDeflatePredictorImpl))); |
| + } else { |
| + return basic_stream.Pass(); |
| + } |
| } |
| void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting( |
| @@ -464,6 +573,7 @@ int WebSocketBasicHandshakeStream::ValidateResponse() { |
| int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| const scoped_refptr<HttpResponseHeaders>& headers) { |
| + extension_params_.reset(new WebSocketExtensionParams); |
| if (ValidateUpgrade(headers.get(), &failure_message_) && |
| ValidateSecWebSocketAccept(headers.get(), |
| handshake_challenge_response_, |
| @@ -476,7 +586,8 @@ int WebSocketBasicHandshakeStream::ValidateUpgradeResponse( |
| ValidateExtensions(headers.get(), |
| requested_extensions_, |
| &extensions_, |
| - &failure_message_)) { |
| + &failure_message_, |
| + extension_params_.get())) { |
| return OK; |
| } |
| failure_message_ = "Error during WebSocket handshake: " + failure_message_; |