Index: net/server/web_socket.cc |
diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc |
index 2ddb03c2477f2b0d550dd9464d46d88ff40356f8..aa88016e72ab8fd3228f49957e3cec39eadb06d1 100644 |
--- a/net/server/web_socket.cc |
+++ b/net/server/web_socket.cc |
@@ -6,7 +6,6 @@ |
#include "base/base64.h" |
#include "base/logging.h" |
-#include "base/md5.h" |
#include "base/sha1.h" |
#include "base/strings/string_number_conversions.h" |
#include "base/strings/stringprintf.h" |
@@ -19,241 +18,85 @@ |
namespace net { |
-namespace { |
- |
-static uint32 WebSocketKeyFingerprint(const std::string& str) { |
- std::string result; |
- const char* p_char = str.c_str(); |
- int length = str.length(); |
- int spaces = 0; |
- for (int i = 0; i < length; ++i) { |
- if (p_char[i] >= '0' && p_char[i] <= '9') |
- result.append(&p_char[i], 1); |
- else if (p_char[i] == ' ') |
- spaces++; |
+WebSocket::WebSocket(HttpServer* server, |
+ HttpConnection* connection, |
+ const HttpServerRequestInfo& request, |
+ size_t* pos) |
+ : server_(server), connection_(connection), closed_(false) { |
+ std::string request_extensions = |
+ request.GetHeaderValue("sec-websocket-extensions"); |
+ encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, |
+ &response_extensions_)); |
+ if (!response_extensions_.empty()) { |
+ response_extensions_ = |
+ "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; |
} |
- if (spaces == 0) |
- return 0; |
- int64 number = 0; |
- if (!base::StringToInt64(result, &number)) |
- return 0; |
- return base::HostToNet32(static_cast<uint32>(number / spaces)); |
} |
-class WebSocketHixie76 : public WebSocket { |
- public: |
- static WebSocket* Create(HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request, |
- size_t* pos) { |
- if (connection->read_buf()->GetSize() < |
- static_cast<int>(*pos + kWebSocketHandshakeBodyLen)) |
- return NULL; |
- return new WebSocketHixie76(server, connection, request, pos); |
- } |
- |
- void Accept(const HttpServerRequestInfo& request) override { |
- std::string key1 = request.GetHeaderValue("sec-websocket-key1"); |
- std::string key2 = request.GetHeaderValue("sec-websocket-key2"); |
- |
- uint32 fp1 = WebSocketKeyFingerprint(key1); |
- uint32 fp2 = WebSocketKeyFingerprint(key2); |
- |
- char data[16]; |
- memcpy(data, &fp1, 4); |
- memcpy(data + 4, &fp2, 4); |
- memcpy(data + 8, &key3_[0], 8); |
- |
- base::MD5Digest digest; |
- base::MD5Sum(data, 16, &digest); |
+WebSocket::~WebSocket() {} |
- std::string origin = request.GetHeaderValue("origin"); |
- std::string host = request.GetHeaderValue("host"); |
- std::string location = "ws://" + host + request.path; |
- server_->SendRaw( |
- connection_->id(), |
- base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
- "Upgrade: WebSocket\r\n" |
- "Connection: Upgrade\r\n" |
- "Sec-WebSocket-Origin: %s\r\n" |
- "Sec-WebSocket-Location: %s\r\n" |
- "\r\n", |
- origin.c_str(), |
- location.c_str())); |
- server_->SendRaw(connection_->id(), |
- std::string(reinterpret_cast<char*>(digest.a), 16)); |
- } |
- |
- ParseResult Read(std::string* message) override { |
- DCHECK(message); |
- HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
- if (read_buf->StartOfBuffer()[0]) |
- return FRAME_ERROR; |
- |
- base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); |
- size_t pos = data.find('\377', 1); |
- if (pos == base::StringPiece::npos) |
- return FRAME_INCOMPLETE; |
- |
- message->assign(data.data() + 1, pos - 1); |
- read_buf->DidConsume(pos + 1); |
- |
- return FRAME_OK; |
- } |
- |
- void Send(const std::string& message) override { |
- char message_start = 0; |
- char message_end = -1; |
- server_->SendRaw(connection_->id(), std::string(1, message_start)); |
- server_->SendRaw(connection_->id(), message); |
- server_->SendRaw(connection_->id(), std::string(1, message_end)); |
+WebSocket* WebSocket::CreateWebSocket(HttpServer* server, |
+ HttpConnection* connection, |
+ const HttpServerRequestInfo& request, |
+ size_t* pos) { |
+ std::string version = request.GetHeaderValue("sec-websocket-version"); |
+ if (version != "8" && version != "13") { |
+ server->SendResponse( |
+ connection->id(), |
+ HttpServerResponseInfo::CreateFor500( |
+ "Invalid request format. The version is not valid.")); |
+ return nullptr; |
} |
- private: |
- static const int kWebSocketHandshakeBodyLen; |
- |
- WebSocketHixie76(HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request, |
- size_t* pos) |
- : WebSocket(server, connection) { |
- std::string key1 = request.GetHeaderValue("sec-websocket-key1"); |
- std::string key2 = request.GetHeaderValue("sec-websocket-key2"); |
- |
- if (key1.empty()) { |
- server->SendResponse( |
- connection->id(), |
- HttpServerResponseInfo::CreateFor500( |
- "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " |
- "specified.")); |
- return; |
- } |
- |
- if (key2.empty()) { |
- server->SendResponse( |
- connection->id(), |
- HttpServerResponseInfo::CreateFor500( |
- "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " |
- "specified.")); |
- return; |
- } |
- |
- key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, |
- kWebSocketHandshakeBodyLen); |
- *pos += kWebSocketHandshakeBodyLen; |
+ std::string key = request.GetHeaderValue("sec-websocket-key"); |
+ if (key.empty()) { |
+ server->SendResponse( |
+ connection->id(), |
+ HttpServerResponseInfo::CreateFor500( |
+ "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
+ "specified.")); |
+ return nullptr; |
} |
+ return new WebSocket(server, connection, request, pos); |
+} |
- std::string key3_; |
- |
- DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); |
-}; |
- |
-const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; |
- |
-class WebSocketHybi17 : public WebSocket { |
- public: |
- static WebSocket* Create(HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request, |
- size_t* pos) { |
- std::string version = request.GetHeaderValue("sec-websocket-version"); |
- if (version != "8" && version != "13") |
- return NULL; |
- |
- std::string key = request.GetHeaderValue("sec-websocket-key"); |
- if (key.empty()) { |
- server->SendResponse( |
- connection->id(), |
- HttpServerResponseInfo::CreateFor500( |
- "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
- "specified.")); |
- return NULL; |
- } |
- return new WebSocketHybi17(server, connection, request, pos); |
- } |
- |
- void Accept(const HttpServerRequestInfo& request) override { |
- static const char* const kWebSocketGuid = |
- "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
- std::string key = request.GetHeaderValue("sec-websocket-key"); |
- std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
- std::string encoded_hash; |
- base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
- |
- server_->SendRaw(connection_->id(), |
- base::StringPrintf( |
- "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
+void WebSocket::Accept(const HttpServerRequestInfo& request) { |
+ static const char* const kWebSocketGuid = |
+ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
+ std::string key = request.GetHeaderValue("sec-websocket-key"); |
+ std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
+ std::string encoded_hash; |
+ base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
+ |
+ server_->SendRaw( |
+ connection_->id(), |
+ base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
"Upgrade: WebSocket\r\n" |
"Connection: Upgrade\r\n" |
"Sec-WebSocket-Accept: %s\r\n" |
"%s" |
"\r\n", |
encoded_hash.c_str(), response_extensions_.c_str())); |
- } |
- |
- ParseResult Read(std::string* message) override { |
- HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
- base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); |
- int bytes_consumed = 0; |
- ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); |
- if (result == FRAME_OK) |
- read_buf->DidConsume(bytes_consumed); |
- if (result == FRAME_CLOSE) |
- closed_ = true; |
- return result; |
- } |
- |
- void Send(const std::string& message) override { |
- if (closed_) |
- return; |
- std::string encoded; |
- encoder_->EncodeFrame(message, 0, &encoded); |
- server_->SendRaw(connection_->id(), encoded); |
- } |
- |
- private: |
- WebSocketHybi17(HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request, |
- size_t* pos) |
- : WebSocket(server, connection), |
- closed_(false) { |
- std::string request_extensions = |
- request.GetHeaderValue("sec-websocket-extensions"); |
- encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, |
- &response_extensions_)); |
- if (!response_extensions_.empty()) { |
- response_extensions_ = |
- "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; |
- } |
- } |
- |
- scoped_ptr<WebSocketEncoder> encoder_; |
- std::string response_extensions_; |
- bool closed_; |
- |
- DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); |
-}; |
- |
-} // anonymous namespace |
- |
-WebSocket* WebSocket::CreateWebSocket(HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request, |
- size_t* pos) { |
- WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); |
- if (socket) |
- return socket; |
- |
- return WebSocketHixie76::Create(server, connection, request, pos); |
} |
-WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) |
- : server_(server), |
- connection_(connection) { |
+WebSocket::ParseResult WebSocket::Read(std::string* message) { |
+ HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
+ base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); |
+ int bytes_consumed = 0; |
+ ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); |
+ if (result == FRAME_OK) |
+ read_buf->DidConsume(bytes_consumed); |
+ if (result == FRAME_CLOSE) |
+ closed_ = true; |
+ return result; |
} |
-WebSocket::~WebSocket() { |
+void WebSocket::Send(const std::string& message) { |
+ if (closed_) |
+ return; |
+ std::string encoded; |
+ encoder_->EncodeFrame(message, 0, &encoded); |
+ server_->SendRaw(connection_->id(), encoded); |
} |
} // namespace net |