| Index: net/websockets/websocket_basic_handshake_stream.cc
|
| diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc
|
| index 3d2bcdfb742421ecfbfd71f592bf59a544dfe664..808383bacf40f99fac7fea6c3640cdb732398ca3 100644
|
| --- a/net/websockets/websocket_basic_handshake_stream.cc
|
| +++ b/net/websockets/websocket_basic_handshake_stream.cc
|
| @@ -6,6 +6,7 @@
|
|
|
| #include <algorithm>
|
| #include <iterator>
|
| +#include <set>
|
| #include <string>
|
| #include <vector>
|
|
|
| @@ -14,6 +15,7 @@
|
| #include "base/bind.h"
|
| #include "base/containers/hash_tables.h"
|
| #include "base/stl_util.h"
|
| +#include "base/strings/string_number_conversions.h"
|
| #include "base/strings/string_util.h"
|
| #include "base/strings/stringprintf.h"
|
| #include "base/time/time.h"
|
| @@ -26,6 +28,10 @@
|
| #include "net/http/http_stream_parser.h"
|
| #include "net/socket/client_socket_handle.h"
|
| #include "net/websockets/websocket_basic_stream.h"
|
| +#include "net/websockets/websocket_deflate_predictor.h"
|
| +#include "net/websockets/websocket_deflate_predictor_impl.h"
|
| +#include "net/websockets/websocket_deflate_stream.h"
|
| +#include "net/websockets/websocket_deflater.h"
|
| #include "net/websockets/websocket_extension_parser.h"
|
| #include "net/websockets/websocket_handshake_constants.h"
|
| #include "net/websockets/websocket_handshake_handler.h"
|
| @@ -34,6 +40,20 @@
|
| #include "net/websockets/websocket_stream.h"
|
|
|
| namespace net {
|
| +
|
| +// TODO(ricea): If more extensions are added, replace this with a more general
|
| +// mechanism.
|
| +struct WebSocketExtensionParams {
|
| + WebSocketExtensionParams()
|
| + : deflate_enabled(false),
|
| + client_window_bits(15),
|
| + deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
|
| +
|
| + bool deflate_enabled;
|
| + int client_window_bits;
|
| + WebSocketDeflater::ContextTakeOverMode deflate_mode;
|
| +};
|
| +
|
| namespace {
|
|
|
| enum GetHeaderResult {
|
| @@ -202,12 +222,80 @@ bool ValidateSubProtocol(
|
| return true;
|
| }
|
|
|
| +bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
|
| + std::string* failure_message,
|
| + WebSocketExtensionParams* params) {
|
| + static const char kClientPrefix[] = "client_";
|
| + static const char kServerPrefix[] = "server_";
|
| + static const char kNoContextTakeover[] = "no_context_takeover";
|
| + static const char kMaxWindowBits[] = "max_window_bits";
|
| + const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
|
| + COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
|
| + the_strings_server_and_client_must_be_the_same_length);
|
| + typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
|
| +
|
| + DCHECK(extension.name() == "permessage-deflate");
|
| + const ParameterVector& parameters = extension.parameters();
|
| + std::set<std::string> seen_names;
|
| + for (ParameterVector::const_iterator it = parameters.begin();
|
| + it != parameters.end(); ++it) {
|
| + const std::string& name = it->name();
|
| + if (seen_names.count(name) != 0) {
|
| + *failure_message =
|
| + "Received duplicate permessage-deflate extension parameter " + name;
|
| + return false;
|
| + }
|
| + seen_names.insert(name);
|
| + const std::string client_or_server(name, 0, kPrefixLen);
|
| + const bool is_client = (client_or_server == kClientPrefix);
|
| + if (!is_client && client_or_server != kServerPrefix) {
|
| + *failure_message =
|
| + "Received an unexpected permessage-deflate extension parameter";
|
| + return false;
|
| + }
|
| + const std::string rest(name, kPrefixLen);
|
| + if (rest == kNoContextTakeover) {
|
| + if (it->HasValue()) {
|
| + *failure_message = "Received invalid " + name + " parameter";
|
| + return false;
|
| + }
|
| + if (is_client)
|
| + params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
|
| + } else if (rest == kMaxWindowBits) {
|
| + if (!it->HasValue()) {
|
| + *failure_message = name + " must have value";
|
| + return false;
|
| + }
|
| + int bits = 0;
|
| + if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
|
| + it->value()[0] == '0' ||
|
| + it->value().find_first_not_of("0123456789") != std::string::npos) {
|
| + *failure_message = "Received invalid " + name + " parameter";
|
| + return false;
|
| + }
|
| + if (is_client)
|
| + params->client_window_bits = bits;
|
| + } else {
|
| + *failure_message =
|
| + "Received an unexpected permessage-deflate extension parameter";
|
| + return false;
|
| + }
|
| + }
|
| + params->deflate_enabled = true;
|
| + return true;
|
| +}
|
| +
|
| bool ValidateExtensions(const HttpResponseHeaders* headers,
|
| const std::vector<std::string>& requested_extensions,
|
| std::string* extensions,
|
| - std::string* failure_message) {
|
| + std::string* failure_message,
|
| + WebSocketExtensionParams* params) {
|
| void* state = NULL;
|
| std::string value;
|
| + std::vector<std::string> accepted_extensions;
|
| + // TODO(ricea): If adding support for additional extensions, generalise this
|
| + // code.
|
| + bool seen_permessage_deflate = false;
|
| while (headers->EnumerateHeader(
|
| &state, websockets::kSecWebSocketExtensions, &value)) {
|
| WebSocketExtensionParser parser;
|
| @@ -220,13 +308,25 @@ bool ValidateExtensions(const HttpResponseHeaders* headers,
|
| value;
|
| return false;
|
| }
|
| - // TODO(ricea): Accept permessage-deflate with valid parameters.
|
| - *failure_message =
|
| - "Found an unsupported extension '" +
|
| - parser.extension().name() +
|
| - "' in 'Sec-WebSocket-Extensions' header";
|
| - return false;
|
| + if (parser.extension().name() == "permessage-deflate") {
|
| + if (seen_permessage_deflate) {
|
| + *failure_message = "Received duplicate permessage-deflate response";
|
| + return false;
|
| + }
|
| + seen_permessage_deflate = true;
|
| + if (!ValidatePerMessageDeflateExtension(
|
| + parser.extension(), failure_message, params))
|
| + return false;
|
| + } else {
|
| + *failure_message =
|
| + "Found an unsupported extension '" +
|
| + parser.extension().name() +
|
| + "' in 'Sec-WebSocket-Extensions' header";
|
| + return false;
|
| + }
|
| + accepted_extensions.push_back(value);
|
| }
|
| + *extensions = JoinString(accepted_extensions, ", ");
|
| return true;
|
| }
|
|
|
| @@ -284,12 +384,12 @@ int WebSocketBasicHandshakeStream::SendRequest(
|
| }
|
| enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
|
|
|
| - AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
|
| - requested_sub_protocols_,
|
| - &enriched_headers);
|
| AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
|
| requested_extensions_,
|
| &enriched_headers);
|
| + AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
|
| + requested_sub_protocols_,
|
| + &enriched_headers);
|
|
|
| ComputeSecWebSocketAccept(handshake_challenge,
|
| &handshake_challenge_response_);
|
| @@ -393,16 +493,25 @@ void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
|
| }
|
|
|
| scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
|
| - // TODO(ricea): Add deflate support.
|
| -
|
| // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
|
| // sure it does not touch it again before it is destroyed.
|
| state_.DeleteParser();
|
| - return scoped_ptr<WebSocketStream>(
|
| + scoped_ptr<WebSocketStream> basic_stream(
|
| new WebSocketBasicStream(state_.ReleaseConnection(),
|
| state_.read_buf(),
|
| sub_protocol_,
|
| extensions_));
|
| + DCHECK(extension_params_.get());
|
| + if (extension_params_->deflate_enabled) {
|
| + return scoped_ptr<WebSocketStream>(
|
| + new WebSocketDeflateStream(basic_stream.Pass(),
|
| + extension_params_->deflate_mode,
|
| + extension_params_->client_window_bits,
|
| + scoped_ptr<WebSocketDeflatePredictor>(
|
| + new WebSocketDeflatePredictorImpl)));
|
| + } else {
|
| + return basic_stream.Pass();
|
| + }
|
| }
|
|
|
| void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
|
| @@ -464,6 +573,7 @@ int WebSocketBasicHandshakeStream::ValidateResponse() {
|
|
|
| int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
|
| const scoped_refptr<HttpResponseHeaders>& headers) {
|
| + extension_params_.reset(new WebSocketExtensionParams);
|
| if (ValidateUpgrade(headers.get(), &failure_message_) &&
|
| ValidateSecWebSocketAccept(headers.get(),
|
| handshake_challenge_response_,
|
| @@ -476,7 +586,8 @@ int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
|
| ValidateExtensions(headers.get(),
|
| requested_extensions_,
|
| &extensions_,
|
| - &failure_message_)) {
|
| + &failure_message_,
|
| + extension_params_.get())) {
|
| return OK;
|
| }
|
| failure_message_ = "Error during WebSocket handshake: " + failure_message_;
|
|
|