| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/server/web_socket.h" | 5 #include "net/server/web_socket.h" |
| 6 | 6 |
| 7 #include <vector> |
| 8 |
| 7 #include "base/base64.h" | 9 #include "base/base64.h" |
| 8 #include "base/logging.h" | 10 #include "base/logging.h" |
| 9 #include "base/sha1.h" | 11 #include "base/sha1.h" |
| 10 #include "base/strings/string_number_conversions.h" | 12 #include "base/strings/string_number_conversions.h" |
| 11 #include "base/strings/stringprintf.h" | 13 #include "base/strings/stringprintf.h" |
| 12 #include "base/sys_byteorder.h" | 14 #include "base/sys_byteorder.h" |
| 13 #include "net/server/http_connection.h" | 15 #include "net/server/http_connection.h" |
| 14 #include "net/server/http_server.h" | 16 #include "net/server/http_server.h" |
| 15 #include "net/server/http_server_request_info.h" | 17 #include "net/server/http_server_request_info.h" |
| 16 #include "net/server/http_server_response_info.h" | 18 #include "net/server/http_server_response_info.h" |
| 17 #include "net/server/web_socket_encoder.h" | 19 #include "net/server/web_socket_encoder.h" |
| 20 #include "net/websockets/websocket_deflate_parameters.h" |
| 21 #include "net/websockets/websocket_extension.h" |
| 22 #include "net/websockets/websocket_handshake_constants.h" |
| 18 | 23 |
| 19 namespace net { | 24 namespace net { |
| 20 | 25 |
| 21 WebSocket::WebSocket(HttpServer* server, | 26 namespace { |
| 22 HttpConnection* connection, | 27 |
| 23 const HttpServerRequestInfo& request, | 28 std::string ExtensionsHeaderString( |
| 24 size_t* pos) | 29 const std::vector<WebSocketExtension>& extensions) { |
| 25 : server_(server), connection_(connection), closed_(false) { | 30 if (extensions.empty()) |
| 26 std::string request_extensions = | 31 return std::string(); |
| 27 request.GetHeaderValue("sec-websocket-extensions"); | 32 |
| 28 encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, | 33 std::string result = "Sec-WebSocket-Extensions: " + extensions[0].ToString(); |
| 29 &response_extensions_)); | 34 for (size_t i = 1; i < extensions.size(); ++i) |
| 30 if (!response_extensions_.empty()) { | 35 result += ", " + extensions[i].ToString(); |
| 31 response_extensions_ = | 36 return result + "\r\n"; |
| 32 "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; | |
| 33 } | |
| 34 } | 37 } |
| 35 | 38 |
| 39 std::string ValidResponseString( |
| 40 const std::string& accept_hash, |
| 41 const std::vector<WebSocketExtension> extensions) { |
| 42 return base::StringPrintf( |
| 43 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
| 44 "Upgrade: WebSocket\r\n" |
| 45 "Connection: Upgrade\r\n" |
| 46 "Sec-WebSocket-Accept: %s\r\n" |
| 47 "%s" |
| 48 "\r\n", |
| 49 accept_hash.c_str(), ExtensionsHeaderString(extensions).c_str()); |
| 50 } |
| 51 |
| 52 } // namespace |
| 53 |
| 54 WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) |
| 55 : server_(server), connection_(connection), closed_(false) {} |
| 56 |
| 36 WebSocket::~WebSocket() {} | 57 WebSocket::~WebSocket() {} |
| 37 | 58 |
| 38 WebSocket* WebSocket::CreateWebSocket(HttpServer* server, | 59 void WebSocket::Accept(const HttpServerRequestInfo& request) { |
| 39 HttpConnection* connection, | |
| 40 const HttpServerRequestInfo& request, | |
| 41 size_t* pos) { | |
| 42 std::string version = request.GetHeaderValue("sec-websocket-version"); | 60 std::string version = request.GetHeaderValue("sec-websocket-version"); |
| 43 if (version != "8" && version != "13") { | 61 if (version != "8" && version != "13") { |
| 44 server->SendResponse( | 62 SendErrorResponse("Invalid request format. The version is not valid."); |
| 45 connection->id(), | 63 return; |
| 46 HttpServerResponseInfo::CreateFor500( | |
| 47 "Invalid request format. The version is not valid.")); | |
| 48 return nullptr; | |
| 49 } | 64 } |
| 50 | 65 |
| 51 std::string key = request.GetHeaderValue("sec-websocket-key"); | 66 std::string key = request.GetHeaderValue("sec-websocket-key"); |
| 52 if (key.empty()) { | 67 if (key.empty()) { |
| 53 server->SendResponse( | 68 SendErrorResponse( |
| 54 connection->id(), | 69 "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
| 55 HttpServerResponseInfo::CreateFor500( | 70 "specified."); |
| 56 "Invalid request format. Sec-WebSocket-Key is empty or isn't " | 71 return; |
| 57 "specified.")); | |
| 58 return nullptr; | |
| 59 } | 72 } |
| 60 return new WebSocket(server, connection, request, pos); | 73 std::string encoded_hash; |
| 61 } | 74 base::Base64Encode(base::SHA1HashString(key + websockets::kWebSocketGuid), |
| 75 &encoded_hash); |
| 62 | 76 |
| 63 void WebSocket::Accept(const HttpServerRequestInfo& request) { | 77 std::vector<WebSocketExtension> response_extensions; |
| 64 static const char* const kWebSocketGuid = | 78 auto i = request.headers.find("sec-websocket-extensions"); |
| 65 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | 79 if (i == request.headers.end()) { |
| 66 std::string key = request.GetHeaderValue("sec-websocket-key"); | 80 encoder_ = WebSocketEncoder::CreateServer(); |
| 67 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); | 81 } else { |
| 68 std::string encoded_hash; | 82 WebSocketDeflateParameters params; |
| 69 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); | 83 encoder_ = WebSocketEncoder::CreateServer(i->second, ¶ms); |
| 70 | 84 if (!encoder_) { |
| 71 server_->SendRaw( | 85 Fail(); |
| 72 connection_->id(), | 86 return; |
| 73 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" | 87 } |
| 74 "Upgrade: WebSocket\r\n" | 88 response_extensions.push_back(params.AsExtension()); |
| 75 "Connection: Upgrade\r\n" | 89 } |
| 76 "Sec-WebSocket-Accept: %s\r\n" | 90 server_->SendRaw(connection_->id(), |
| 77 "%s" | 91 ValidResponseString(encoded_hash, response_extensions)); |
| 78 "\r\n", | |
| 79 encoded_hash.c_str(), response_extensions_.c_str())); | |
| 80 } | 92 } |
| 81 | 93 |
| 82 WebSocket::ParseResult WebSocket::Read(std::string* message) { | 94 WebSocket::ParseResult WebSocket::Read(std::string* message) { |
| 95 if (closed_) |
| 96 return FRAME_CLOSE; |
| 97 |
| 83 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); | 98 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
| 84 base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); | 99 base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); |
| 85 int bytes_consumed = 0; | 100 int bytes_consumed = 0; |
| 86 ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); | 101 ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); |
| 87 if (result == FRAME_OK) | 102 if (result == FRAME_OK) |
| 88 read_buf->DidConsume(bytes_consumed); | 103 read_buf->DidConsume(bytes_consumed); |
| 89 if (result == FRAME_CLOSE) | 104 if (result == FRAME_CLOSE) |
| 90 closed_ = true; | 105 closed_ = true; |
| 91 return result; | 106 return result; |
| 92 } | 107 } |
| 93 | 108 |
| 94 void WebSocket::Send(const std::string& message) { | 109 void WebSocket::Send(const std::string& message) { |
| 95 if (closed_) | 110 if (closed_) |
| 96 return; | 111 return; |
| 97 std::string encoded; | 112 std::string encoded; |
| 98 encoder_->EncodeFrame(message, 0, &encoded); | 113 encoder_->EncodeFrame(message, 0, &encoded); |
| 99 server_->SendRaw(connection_->id(), encoded); | 114 server_->SendRaw(connection_->id(), encoded); |
| 100 } | 115 } |
| 101 | 116 |
| 117 void WebSocket::Fail() { |
| 118 closed_ = true; |
| 119 // TODO(yhirano): The server SHOULD log the problem. |
| 120 server_->Close(connection_->id()); |
| 121 } |
| 122 |
| 123 void WebSocket::SendErrorResponse(const std::string& message) { |
| 124 if (closed_) |
| 125 return; |
| 126 closed_ = true; |
| 127 server_->Send500(connection_->id(), message); |
| 128 } |
| 129 |
| 102 } // namespace net | 130 } // namespace net |
| OLD | NEW |