| Index: net/server/web_socket.cc
|
| diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc
|
| index aa88016e72ab8fd3228f49957e3cec39eadb06d1..d4e16c8349e8fd7cce36795323dd22684b46c922 100644
|
| --- a/net/server/web_socket.cc
|
| +++ b/net/server/web_socket.cc
|
| @@ -4,6 +4,8 @@
|
|
|
| #include "net/server/web_socket.h"
|
|
|
| +#include <vector>
|
| +
|
| #include "base/base64.h"
|
| #include "base/logging.h"
|
| #include "base/sha1.h"
|
| @@ -15,71 +17,84 @@
|
| #include "net/server/http_server_request_info.h"
|
| #include "net/server/http_server_response_info.h"
|
| #include "net/server/web_socket_encoder.h"
|
| +#include "net/websockets/websocket_deflate_parameters.h"
|
| +#include "net/websockets/websocket_extension.h"
|
| +#include "net/websockets/websocket_handshake_constants.h"
|
|
|
| namespace net {
|
|
|
| -WebSocket::WebSocket(HttpServer* server,
|
| - HttpConnection* connection,
|
| - const HttpServerRequestInfo& request,
|
| - size_t* pos)
|
| - : server_(server), connection_(connection), closed_(false) {
|
| - std::string request_extensions =
|
| - request.GetHeaderValue("sec-websocket-extensions");
|
| - encoder_.reset(WebSocketEncoder::CreateServer(request_extensions,
|
| - &response_extensions_));
|
| - if (!response_extensions_.empty()) {
|
| - response_extensions_ =
|
| - "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n";
|
| - }
|
| +namespace {
|
| +
|
| +std::string ExtensionsHeaderString(
|
| + const std::vector<WebSocketExtension>& extensions) {
|
| + if (extensions.empty())
|
| + return std::string();
|
| +
|
| + std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString();
|
| + for (size_t i = 1; i < extensions.size(); ++i)
|
| + result += ", " + extensions[i].ToString();
|
| + return result + "\r\n";
|
| }
|
|
|
| +std::string ValidResponseString(
|
| + const std::string& accept_hash,
|
| + const std::vector<WebSocketExtension> extensions) {
|
| + return base::StringPrintf(
|
| + "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
|
| + "Upgrade: WebSocket\r\n"
|
| + "Connection: Upgrade\r\n"
|
| + "Sec-WebSocket-Accept: %s\r\n"
|
| + "%s"
|
| + "\r\n",
|
| + accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str());
|
| +}
|
| +
|
| +} // namespace
|
| +
|
| +WebSocket::WebSocket(HttpServer* server, HttpConnection* connection)
|
| + : server_(server), connection_(connection), closed_(false) {}
|
| +
|
| WebSocket::~WebSocket() {}
|
|
|
| -WebSocket* WebSocket::CreateWebSocket(HttpServer* server,
|
| - HttpConnection* connection,
|
| - const HttpServerRequestInfo& request,
|
| - size_t* pos) {
|
| +void WebSocket::Accept(const HttpServerRequestInfo& request) {
|
| std::string version = request.GetHeaderValue("sec-websocket-version");
|
| if (version != "8" && version != "13") {
|
| - server->SendResponse(
|
| - connection->id(),
|
| - HttpServerResponseInfo::CreateFor500(
|
| - "Invalid request format. The version is not valid."));
|
| - return nullptr;
|
| + SendErrorResponse("Invalid request format. The version is not valid.");
|
| + return;
|
| }
|
|
|
| std::string key = request.GetHeaderValue("sec-websocket-key");
|
| if (key.empty()) {
|
| - server->SendResponse(
|
| - connection->id(),
|
| - HttpServerResponseInfo::CreateFor500(
|
| - "Invalid request format. Sec-WebSocket-Key is empty or isn't "
|
| - "specified."));
|
| - return nullptr;
|
| + SendErrorResponse(
|
| + "Invalid request format. Sec-WebSocket-Key is empty or isn't "
|
| + "specified.");
|
| + return;
|
| }
|
| - return new WebSocket(server, connection, request, pos);
|
| -}
|
| -
|
| -void WebSocket::Accept(const HttpServerRequestInfo& request) {
|
| - static const char* const kWebSocketGuid =
|
| - "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
|
| - std::string key = request.GetHeaderValue("sec-websocket-key");
|
| - std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid);
|
| std::string encoded_hash;
|
| - base::Base64Encode(base::SHA1HashString(data), &encoded_hash);
|
| -
|
| - server_->SendRaw(
|
| - connection_->id(),
|
| - base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
|
| - "Upgrade: WebSocket\r\n"
|
| - "Connection: Upgrade\r\n"
|
| - "Sec-WebSocket-Accept: %s\r\n"
|
| - "%s"
|
| - "\r\n",
|
| - encoded_hash.c_str(), response_extensions_.c_str()));
|
| + base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid),
|
| + &encoded_hash);
|
| +
|
| + std::vector<WebSocketExtension> response_extensions;
|
| + auto i = request.headers.find("sec-websocket-extensions");
|
| + if (i == request.headers.end()) {
|
| + encoder_ = WebSocketEncoder::CreateServer();
|
| + } else {
|
| + WebSocketDeflateParameters params;
|
| + encoder_ = WebSocketEncoder::CreateServer(i->second, ¶ms);
|
| + if (!encoder_) {
|
| + Fail();
|
| + return;
|
| + }
|
| + response_extensions.push_back(params.AsExtension());
|
| + }
|
| + server_->SendRaw(connection_->id(),
|
| + ValidResponseString(encoded_hash, response_extensions));
|
| }
|
|
|
| WebSocket::ParseResult WebSocket::Read(std::string* message) {
|
| + if (closed_)
|
| + return FRAME_CLOSE;
|
| +
|
| HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf();
|
| base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize());
|
| int bytes_consumed = 0;
|
| @@ -99,4 +114,17 @@ void WebSocket::Send(const std::string& message) {
|
| server_->SendRaw(connection_->id(), encoded);
|
| }
|
|
|
| +void WebSocket::Fail() {
|
| + closed_ = true;
|
| + // TODO(yhirano): The server SHOULD log the problem.
|
| + server_->Close(connection_->id());
|
| +}
|
| +
|
| +void WebSocket::SendErrorResponse(const std::string& message) {
|
| + if (closed_)
|
| + return;
|
| + closed_ = true;
|
| + server_->Send500(connection_->id(), message);
|
| +}
|
| +
|
| } // namespace net
|
|
|