Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(879)

Unified Diff: net/websockets/websocket_basic_handshake_stream.cc

Issue 143913003: Add construction of WebSocketDeflateStream (Closed) Base URL: http://git.chromium.org/chromium/src.git@master
Patch Set: Rebase and add DCHECK() for extension params. Created 6 years, 11 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: net/websockets/websocket_basic_handshake_stream.cc
diff --git a/net/websockets/websocket_basic_handshake_stream.cc b/net/websockets/websocket_basic_handshake_stream.cc
index 3d2bcdfb742421ecfbfd71f592bf59a544dfe664..808383bacf40f99fac7fea6c3640cdb732398ca3 100644
--- a/net/websockets/websocket_basic_handshake_stream.cc
+++ b/net/websockets/websocket_basic_handshake_stream.cc
@@ -6,6 +6,7 @@
#include <algorithm>
#include <iterator>
+#include <set>
#include <string>
#include <vector>
@@ -14,6 +15,7 @@
#include "base/bind.h"
#include "base/containers/hash_tables.h"
#include "base/stl_util.h"
+#include "base/strings/string_number_conversions.h"
#include "base/strings/string_util.h"
#include "base/strings/stringprintf.h"
#include "base/time/time.h"
@@ -26,6 +28,10 @@
#include "net/http/http_stream_parser.h"
#include "net/socket/client_socket_handle.h"
#include "net/websockets/websocket_basic_stream.h"
+#include "net/websockets/websocket_deflate_predictor.h"
+#include "net/websockets/websocket_deflate_predictor_impl.h"
+#include "net/websockets/websocket_deflate_stream.h"
+#include "net/websockets/websocket_deflater.h"
#include "net/websockets/websocket_extension_parser.h"
#include "net/websockets/websocket_handshake_constants.h"
#include "net/websockets/websocket_handshake_handler.h"
@@ -34,6 +40,20 @@
#include "net/websockets/websocket_stream.h"
namespace net {
+
+// TODO(ricea): If more extensions are added, replace this with a more general
+// mechanism.
+struct WebSocketExtensionParams {
+ WebSocketExtensionParams()
+ : deflate_enabled(false),
+ client_window_bits(15),
+ deflate_mode(WebSocketDeflater::TAKE_OVER_CONTEXT) {}
+
+ bool deflate_enabled;
+ int client_window_bits;
+ WebSocketDeflater::ContextTakeOverMode deflate_mode;
+};
+
namespace {
enum GetHeaderResult {
@@ -202,12 +222,80 @@ bool ValidateSubProtocol(
return true;
}
+bool ValidatePerMessageDeflateExtension(const WebSocketExtension& extension,
+ std::string* failure_message,
+ WebSocketExtensionParams* params) {
+ static const char kClientPrefix[] = "client_";
+ static const char kServerPrefix[] = "server_";
+ static const char kNoContextTakeover[] = "no_context_takeover";
+ static const char kMaxWindowBits[] = "max_window_bits";
+ const size_t kPrefixLen = arraysize(kClientPrefix) - 1;
+ COMPILE_ASSERT(kPrefixLen == arraysize(kServerPrefix) - 1,
+ the_strings_server_and_client_must_be_the_same_length);
+ typedef std::vector<WebSocketExtension::Parameter> ParameterVector;
+
+ DCHECK(extension.name() == "permessage-deflate");
+ const ParameterVector& parameters = extension.parameters();
+ std::set<std::string> seen_names;
+ for (ParameterVector::const_iterator it = parameters.begin();
+ it != parameters.end(); ++it) {
+ const std::string& name = it->name();
+ if (seen_names.count(name) != 0) {
+ *failure_message =
+ "Received duplicate permessage-deflate extension parameter " + name;
+ return false;
+ }
+ seen_names.insert(name);
+ const std::string client_or_server(name, 0, kPrefixLen);
+ const bool is_client = (client_or_server == kClientPrefix);
+ if (!is_client && client_or_server != kServerPrefix) {
+ *failure_message =
+ "Received an unexpected permessage-deflate extension parameter";
+ return false;
+ }
+ const std::string rest(name, kPrefixLen);
+ if (rest == kNoContextTakeover) {
+ if (it->HasValue()) {
+ *failure_message = "Received invalid " + name + " parameter";
+ return false;
+ }
+ if (is_client)
+ params->deflate_mode = WebSocketDeflater::DO_NOT_TAKE_OVER_CONTEXT;
+ } else if (rest == kMaxWindowBits) {
+ if (!it->HasValue()) {
+ *failure_message = name + " must have value";
+ return false;
+ }
+ int bits = 0;
+ if (!base::StringToInt(it->value(), &bits) || bits < 8 || bits > 15 ||
+ it->value()[0] == '0' ||
+ it->value().find_first_not_of("0123456789") != std::string::npos) {
+ *failure_message = "Received invalid " + name + " parameter";
+ return false;
+ }
+ if (is_client)
+ params->client_window_bits = bits;
+ } else {
+ *failure_message =
+ "Received an unexpected permessage-deflate extension parameter";
+ return false;
+ }
+ }
+ params->deflate_enabled = true;
+ return true;
+}
+
bool ValidateExtensions(const HttpResponseHeaders* headers,
const std::vector<std::string>& requested_extensions,
std::string* extensions,
- std::string* failure_message) {
+ std::string* failure_message,
+ WebSocketExtensionParams* params) {
void* state = NULL;
std::string value;
+ std::vector<std::string> accepted_extensions;
+ // TODO(ricea): If adding support for additional extensions, generalise this
+ // code.
+ bool seen_permessage_deflate = false;
while (headers->EnumerateHeader(
&state, websockets::kSecWebSocketExtensions, &value)) {
WebSocketExtensionParser parser;
@@ -220,13 +308,25 @@ bool ValidateExtensions(const HttpResponseHeaders* headers,
value;
return false;
}
- // TODO(ricea): Accept permessage-deflate with valid parameters.
- *failure_message =
- "Found an unsupported extension '" +
- parser.extension().name() +
- "' in 'Sec-WebSocket-Extensions' header";
- return false;
+ if (parser.extension().name() == "permessage-deflate") {
+ if (seen_permessage_deflate) {
+ *failure_message = "Received duplicate permessage-deflate response";
+ return false;
+ }
+ seen_permessage_deflate = true;
+ if (!ValidatePerMessageDeflateExtension(
+ parser.extension(), failure_message, params))
+ return false;
+ } else {
+ *failure_message =
+ "Found an unsupported extension '" +
+ parser.extension().name() +
+ "' in 'Sec-WebSocket-Extensions' header";
+ return false;
+ }
+ accepted_extensions.push_back(value);
}
+ *extensions = JoinString(accepted_extensions, ", ");
return true;
}
@@ -284,12 +384,12 @@ int WebSocketBasicHandshakeStream::SendRequest(
}
enriched_headers.SetHeader(websockets::kSecWebSocketKey, handshake_challenge);
- AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
- requested_sub_protocols_,
- &enriched_headers);
AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketExtensions,
requested_extensions_,
&enriched_headers);
+ AddVectorHeaderIfNonEmpty(websockets::kSecWebSocketProtocol,
+ requested_sub_protocols_,
+ &enriched_headers);
ComputeSecWebSocketAccept(handshake_challenge,
&handshake_challenge_response_);
@@ -393,16 +493,25 @@ void WebSocketBasicHandshakeStream::SetPriority(RequestPriority priority) {
}
scoped_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() {
- // TODO(ricea): Add deflate support.
-
// The HttpStreamParser object has a pointer to our ClientSocketHandle. Make
// sure it does not touch it again before it is destroyed.
state_.DeleteParser();
- return scoped_ptr<WebSocketStream>(
+ scoped_ptr<WebSocketStream> basic_stream(
new WebSocketBasicStream(state_.ReleaseConnection(),
state_.read_buf(),
sub_protocol_,
extensions_));
+ DCHECK(extension_params_.get());
+ if (extension_params_->deflate_enabled) {
+ return scoped_ptr<WebSocketStream>(
+ new WebSocketDeflateStream(basic_stream.Pass(),
+ extension_params_->deflate_mode,
+ extension_params_->client_window_bits,
+ scoped_ptr<WebSocketDeflatePredictor>(
+ new WebSocketDeflatePredictorImpl)));
+ } else {
+ return basic_stream.Pass();
+ }
}
void WebSocketBasicHandshakeStream::SetWebSocketKeyForTesting(
@@ -464,6 +573,7 @@ int WebSocketBasicHandshakeStream::ValidateResponse() {
int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
const scoped_refptr<HttpResponseHeaders>& headers) {
+ extension_params_.reset(new WebSocketExtensionParams);
if (ValidateUpgrade(headers.get(), &failure_message_) &&
ValidateSecWebSocketAccept(headers.get(),
handshake_challenge_response_,
@@ -476,7 +586,8 @@ int WebSocketBasicHandshakeStream::ValidateUpgradeResponse(
ValidateExtensions(headers.get(),
requested_extensions_,
&extensions_,
- &failure_message_)) {
+ &failure_message_,
+ extension_params_.get())) {
return OK;
}
failure_message_ = "Error during WebSocket handshake: " + failure_message_;
« no previous file with comments | « net/websockets/websocket_basic_handshake_stream.h ('k') | net/websockets/websocket_handshake_stream_create_helper.cc » ('j') | no next file with comments »

Powered by Google App Engine
This is Rietveld 408576698