Index: net/server/web_socket.cc |
diff --git a/net/server/web_socket.cc b/net/server/web_socket.cc |
index c9637450d4a1ab9aa8b584dde85fff60293eb26b..79ffcecb7ada589c1c3171c27caedea4226931d4 100644 |
--- a/net/server/web_socket.cc |
+++ b/net/server/web_socket.cc |
@@ -4,6 +4,8 @@ |
#include "net/server/web_socket.h" |
+#include <vector> |
+ |
#include "base/base64.h" |
#include "base/logging.h" |
#include "base/sha1.h" |
@@ -15,70 +17,87 @@ |
#include "net/server/http_server_request_info.h" |
#include "net/server/http_server_response_info.h" |
#include "net/server/web_socket_encoder.h" |
+#include "net/websockets/websocket_deflate_parameters.h" |
+#include "net/websockets/websocket_extension.h" |
+#include "net/websockets/websocket_handshake_constants.h" |
namespace net { |
-WebSocket::WebSocket(HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request) |
- : server_(server), connection_(connection), closed_(false) { |
- std::string request_extensions = |
- request.GetHeaderValue("sec-websocket-extensions"); |
- encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, |
- &response_extensions_)); |
- if (!response_extensions_.empty()) { |
- response_extensions_ = |
- "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; |
- } |
+namespace { |
+ |
+std::string ExtensionsHeaderString( |
+ const std::vector<WebSocketExtension>& extensions) { |
+ if (extensions.empty()) |
+ return std::string(); |
+ |
+ std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString(); |
+ for (size_t i = 1; i < extensions.size(); ++i) |
+ result += ", " + extensions[i].ToString(); |
+ return result + "\r\n"; |
} |
+std::string ValidResponseString( |
+ const std::string& accept_hash, |
+ const std::vector<WebSocketExtension> extensions) { |
+ return base::StringPrintf( |
+ "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
+ "Upgrade: WebSocket\r\n" |
+ "Connection: Upgrade\r\n" |
+ "Sec-WebSocket-Accept: %s\r\n" |
+ "%s" |
+ "\r\n", |
+ accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str()); |
+} |
+ |
+} // namespace |
+ |
+WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) |
+ : server_(server), connection_(connection), closed_(false) {} |
+ |
WebSocket::~WebSocket() {} |
-scoped_ptr<WebSocket> WebSocket::CreateWebSocket( |
- HttpServer* server, |
- HttpConnection* connection, |
- const HttpServerRequestInfo& request) { |
+void WebSocket::Accept(const HttpServerRequestInfo& request) { |
std::string version = request.GetHeaderValue("sec-websocket-version"); |
if (version != "8" && version != "13") { |
- server->SendResponse( |
- connection->id(), |
- HttpServerResponseInfo::CreateFor500( |
- "Invalid request format. The version is not valid.")); |
- return nullptr; |
+ SendErrorResponse("Invalid request format. The version is not valid."); |
+ return; |
} |
std::string key = request.GetHeaderValue("sec-websocket-key"); |
if (key.empty()) { |
- server->SendResponse( |
- connection->id(), |
- HttpServerResponseInfo::CreateFor500( |
- "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
- "specified.")); |
- return nullptr; |
+ SendErrorResponse( |
+ "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
+ "specified."); |
+ return; |
} |
- return make_scoped_ptr(new WebSocket(server, connection, request)); |
-} |
- |
-void WebSocket::Accept(const HttpServerRequestInfo& request) { |
- static const char* const kWebSocketGuid = |
- "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
- std::string key = request.GetHeaderValue("sec-websocket-key"); |
- std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
std::string encoded_hash; |
- base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
- |
- server_->SendRaw( |
- connection_->id(), |
- base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
- "Upgrade: WebSocket\r\n" |
- "Connection: Upgrade\r\n" |
- "Sec-WebSocket-Accept: %s\r\n" |
- "%s" |
- "\r\n", |
- encoded_hash.c_str(), response_extensions_.c_str())); |
+ base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid), |
+ &encoded_hash); |
+ |
+ std::vector<WebSocketExtension> response_extensions; |
+ auto i = request.headers.find("sec-websocket-extensions"); |
+ if (i == request.headers.end()) { |
+ encoder_ = WebSocketEncoder::CreateServer(); |
+ } else { |
+ WebSocketDeflateParameters params; |
+ encoder_ = WebSocketEncoder::CreateServer(i->second, ¶ms); |
+ if (!encoder_) { |
+ Fail(); |
+ return; |
+ } |
+ if (encoder_->deflate_enabled()) { |
+ DCHECK(params.IsValidAsResponse()); |
+ response_extensions.push_back(params.AsExtension()); |
+ } |
+ } |
+ server_->SendRaw(connection_->id(), |
+ ValidResponseString(encoded_hash, response_extensions)); |
} |
WebSocket::ParseResult WebSocket::Read(std::string* message) { |
+ if (closed_) |
+ return FRAME_CLOSE; |
+ |
HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); |
int bytes_consumed = 0; |
@@ -98,4 +117,17 @@ void WebSocket::Send(const std::string& message) { |
server_->SendRaw(connection_->id(), encoded); |
} |
+void WebSocket::Fail() { |
+ closed_ = true; |
+ // TODO(yhirano): The server SHOULD log the problem. |
+ server_->Close(connection_->id()); |
+} |
+ |
+void WebSocket::SendErrorResponse(const std::string& message) { |
+ if (closed_) |
+ return; |
+ closed_ = true; |
+ server_->Send500(connection_->id(), message); |
+} |
+ |
} // namespace net |