| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/server/web_socket.h" | 5 #include "net/server/web_socket.h" |
| 6 | 6 |
| 7 #include "base/base64.h" | 7 #include "base/base64.h" |
| 8 #include "base/logging.h" | 8 #include "base/logging.h" |
| 9 #include "base/md5.h" | |
| 10 #include "base/sha1.h" | 9 #include "base/sha1.h" |
| 11 #include "base/strings/string_number_conversions.h" | 10 #include "base/strings/string_number_conversions.h" |
| 12 #include "base/strings/stringprintf.h" | 11 #include "base/strings/stringprintf.h" |
| 13 #include "base/sys_byteorder.h" | 12 #include "base/sys_byteorder.h" |
| 14 #include "net/server/http_connection.h" | 13 #include "net/server/http_connection.h" |
| 15 #include "net/server/http_server.h" | 14 #include "net/server/http_server.h" |
| 16 #include "net/server/http_server_request_info.h" | 15 #include "net/server/http_server_request_info.h" |
| 17 #include "net/server/http_server_response_info.h" | 16 #include "net/server/http_server_response_info.h" |
| 18 #include "net/server/web_socket_encoder.h" | 17 #include "net/server/web_socket_encoder.h" |
| 19 | 18 |
| 20 namespace net { | 19 namespace net { |
| 21 | 20 |
| 22 namespace { | 21 WebSocket::WebSocket(HttpServer* server, |
| 23 | 22 HttpConnection* connection, |
| 24 static uint32 WebSocketKeyFingerprint(const std::string& str) { | 23 const HttpServerRequestInfo& request, |
| 25 std::string result; | 24 size_t* pos) |
| 26 const char* p_char = str.c_str(); | 25 : server_(server), connection_(connection), closed_(false) { |
| 27 int length = str.length(); | 26 std::string request_extensions = |
| 28 int spaces = 0; | 27 request.GetHeaderValue("sec-websocket-extensions"); |
| 29 for (int i = 0; i < length; ++i) { | 28 encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, |
| 30 if (p_char[i] >= '0' && p_char[i] <= '9') | 29 &response_extensions_)); |
| 31 result.append(&p_char[i], 1); | 30 if (!response_extensions_.empty()) { |
| 32 else if (p_char[i] == ' ') | 31 response_extensions_ = |
| 33 spaces++; | 32 "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; |
| 34 } | 33 } |
| 35 if (spaces == 0) | |
| 36 return 0; | |
| 37 int64 number = 0; | |
| 38 if (!base::StringToInt64(result, &number)) | |
| 39 return 0; | |
| 40 return base::HostToNet32(static_cast<uint32>(number / spaces)); | |
| 41 } | 34 } |
| 42 | 35 |
| 43 class WebSocketHixie76 : public WebSocket { | 36 WebSocket::~WebSocket() {} |
| 44 public: | 37 |
| 45 static WebSocket* Create(HttpServer* server, | 38 WebSocket* WebSocket::CreateWebSocket(HttpServer* server, |
| 46 HttpConnection* connection, | 39 HttpConnection* connection, |
| 47 const HttpServerRequestInfo& request, | 40 const HttpServerRequestInfo& request, |
| 48 size_t* pos) { | 41 size_t* pos) { |
| 49 if (connection->read_buf()->GetSize() < | 42 std::string version = request.GetHeaderValue("sec-websocket-version"); |
| 50 static_cast<int>(*pos + kWebSocketHandshakeBodyLen)) | 43 if (version != "8" && version != "13") { |
| 51 return NULL; | 44 server->SendResponse( |
| 52 return new WebSocketHixie76(server, connection, request, pos); | 45 connection->id(), |
| 46 HttpServerResponseInfo::CreateFor500( |
| 47 "Invalid request format. The version is not valid.")); |
| 48 return nullptr; |
| 53 } | 49 } |
| 54 | 50 |
| 55 void Accept(const HttpServerRequestInfo& request) override { | 51 std::string key = request.GetHeaderValue("sec-websocket-key"); |
| 56 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); | 52 if (key.empty()) { |
| 57 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); | 53 server->SendResponse( |
| 54 connection->id(), |
| 55 HttpServerResponseInfo::CreateFor500( |
| 56 "Invalid request format. Sec-WebSocket-Key is empty or isn't " |
| 57 "specified.")); |
| 58 return nullptr; |
| 59 } |
| 60 return new WebSocket(server, connection, request, pos); |
| 61 } |
| 58 | 62 |
| 59 uint32 fp1 = WebSocketKeyFingerprint(key1); | 63 void WebSocket::Accept(const HttpServerRequestInfo& request) { |
| 60 uint32 fp2 = WebSocketKeyFingerprint(key2); | 64 static const char* const kWebSocketGuid = |
| 65 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; |
| 66 std::string key = request.GetHeaderValue("sec-websocket-key"); |
| 67 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); |
| 68 std::string encoded_hash; |
| 69 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); |
| 61 | 70 |
| 62 char data[16]; | 71 server_->SendRaw( |
| 63 memcpy(data, &fp1, 4); | 72 connection_->id(), |
| 64 memcpy(data + 4, &fp2, 4); | 73 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" |
| 65 memcpy(data + 8, &key3_[0], 8); | |
| 66 | |
| 67 base::MD5Digest digest; | |
| 68 base::MD5Sum(data, 16, &digest); | |
| 69 | |
| 70 std::string origin = request.GetHeaderValue("origin"); | |
| 71 std::string host = request.GetHeaderValue("host"); | |
| 72 std::string location = "ws://" + host + request.path; | |
| 73 server_->SendRaw( | |
| 74 connection_->id(), | |
| 75 base::StringPrintf("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" | |
| 76 "Upgrade: WebSocket\r\n" | |
| 77 "Connection: Upgrade\r\n" | |
| 78 "Sec-WebSocket-Origin: %s\r\n" | |
| 79 "Sec-WebSocket-Location: %s\r\n" | |
| 80 "\r\n", | |
| 81 origin.c_str(), | |
| 82 location.c_str())); | |
| 83 server_->SendRaw(connection_->id(), | |
| 84 std::string(reinterpret_cast<char*>(digest.a), 16)); | |
| 85 } | |
| 86 | |
| 87 ParseResult Read(std::string* message) override { | |
| 88 DCHECK(message); | |
| 89 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); | |
| 90 if (read_buf->StartOfBuffer()[0]) | |
| 91 return FRAME_ERROR; | |
| 92 | |
| 93 base::StringPiece data(read_buf->StartOfBuffer(), read_buf->GetSize()); | |
| 94 size_t pos = data.find('\377', 1); | |
| 95 if (pos == base::StringPiece::npos) | |
| 96 return FRAME_INCOMPLETE; | |
| 97 | |
| 98 message->assign(data.data() + 1, pos - 1); | |
| 99 read_buf->DidConsume(pos + 1); | |
| 100 | |
| 101 return FRAME_OK; | |
| 102 } | |
| 103 | |
| 104 void Send(const std::string& message) override { | |
| 105 char message_start = 0; | |
| 106 char message_end = -1; | |
| 107 server_->SendRaw(connection_->id(), std::string(1, message_start)); | |
| 108 server_->SendRaw(connection_->id(), message); | |
| 109 server_->SendRaw(connection_->id(), std::string(1, message_end)); | |
| 110 } | |
| 111 | |
| 112 private: | |
| 113 static const int kWebSocketHandshakeBodyLen; | |
| 114 | |
| 115 WebSocketHixie76(HttpServer* server, | |
| 116 HttpConnection* connection, | |
| 117 const HttpServerRequestInfo& request, | |
| 118 size_t* pos) | |
| 119 : WebSocket(server, connection) { | |
| 120 std::string key1 = request.GetHeaderValue("sec-websocket-key1"); | |
| 121 std::string key2 = request.GetHeaderValue("sec-websocket-key2"); | |
| 122 | |
| 123 if (key1.empty()) { | |
| 124 server->SendResponse( | |
| 125 connection->id(), | |
| 126 HttpServerResponseInfo::CreateFor500( | |
| 127 "Invalid request format. Sec-WebSocket-Key1 is empty or isn't " | |
| 128 "specified.")); | |
| 129 return; | |
| 130 } | |
| 131 | |
| 132 if (key2.empty()) { | |
| 133 server->SendResponse( | |
| 134 connection->id(), | |
| 135 HttpServerResponseInfo::CreateFor500( | |
| 136 "Invalid request format. Sec-WebSocket-Key2 is empty or isn't " | |
| 137 "specified.")); | |
| 138 return; | |
| 139 } | |
| 140 | |
| 141 key3_.assign(connection->read_buf()->StartOfBuffer() + *pos, | |
| 142 kWebSocketHandshakeBodyLen); | |
| 143 *pos += kWebSocketHandshakeBodyLen; | |
| 144 } | |
| 145 | |
| 146 std::string key3_; | |
| 147 | |
| 148 DISALLOW_COPY_AND_ASSIGN(WebSocketHixie76); | |
| 149 }; | |
| 150 | |
| 151 const int WebSocketHixie76::kWebSocketHandshakeBodyLen = 8; | |
| 152 | |
| 153 class WebSocketHybi17 : public WebSocket { | |
| 154 public: | |
| 155 static WebSocket* Create(HttpServer* server, | |
| 156 HttpConnection* connection, | |
| 157 const HttpServerRequestInfo& request, | |
| 158 size_t* pos) { | |
| 159 std::string version = request.GetHeaderValue("sec-websocket-version"); | |
| 160 if (version != "8" && version != "13") | |
| 161 return NULL; | |
| 162 | |
| 163 std::string key = request.GetHeaderValue("sec-websocket-key"); | |
| 164 if (key.empty()) { | |
| 165 server->SendResponse( | |
| 166 connection->id(), | |
| 167 HttpServerResponseInfo::CreateFor500( | |
| 168 "Invalid request format. Sec-WebSocket-Key is empty or isn't " | |
| 169 "specified.")); | |
| 170 return NULL; | |
| 171 } | |
| 172 return new WebSocketHybi17(server, connection, request, pos); | |
| 173 } | |
| 174 | |
| 175 void Accept(const HttpServerRequestInfo& request) override { | |
| 176 static const char* const kWebSocketGuid = | |
| 177 "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; | |
| 178 std::string key = request.GetHeaderValue("sec-websocket-key"); | |
| 179 std::string data = base::StringPrintf("%s%s", key.c_str(), kWebSocketGuid); | |
| 180 std::string encoded_hash; | |
| 181 base::Base64Encode(base::SHA1HashString(data), &encoded_hash); | |
| 182 | |
| 183 server_->SendRaw(connection_->id(), | |
| 184 base::StringPrintf( | |
| 185 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" | |
| 186 "Upgrade: WebSocket\r\n" | 74 "Upgrade: WebSocket\r\n" |
| 187 "Connection: Upgrade\r\n" | 75 "Connection: Upgrade\r\n" |
| 188 "Sec-WebSocket-Accept: %s\r\n" | 76 "Sec-WebSocket-Accept: %s\r\n" |
| 189 "%s" | 77 "%s" |
| 190 "\r\n", | 78 "\r\n", |
| 191 encoded_hash.c_str(), response_extensions_.c_str())); | 79 encoded_hash.c_str(), response_extensions_.c_str())); |
| 192 } | |
| 193 | |
| 194 ParseResult Read(std::string* message) override { | |
| 195 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); | |
| 196 base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); | |
| 197 int bytes_consumed = 0; | |
| 198 ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); | |
| 199 if (result == FRAME_OK) | |
| 200 read_buf->DidConsume(bytes_consumed); | |
| 201 if (result == FRAME_CLOSE) | |
| 202 closed_ = true; | |
| 203 return result; | |
| 204 } | |
| 205 | |
| 206 void Send(const std::string& message) override { | |
| 207 if (closed_) | |
| 208 return; | |
| 209 std::string encoded; | |
| 210 encoder_->EncodeFrame(message, 0, &encoded); | |
| 211 server_->SendRaw(connection_->id(), encoded); | |
| 212 } | |
| 213 | |
| 214 private: | |
| 215 WebSocketHybi17(HttpServer* server, | |
| 216 HttpConnection* connection, | |
| 217 const HttpServerRequestInfo& request, | |
| 218 size_t* pos) | |
| 219 : WebSocket(server, connection), | |
| 220 closed_(false) { | |
| 221 std::string request_extensions = | |
| 222 request.GetHeaderValue("sec-websocket-extensions"); | |
| 223 encoder_.reset(WebSocketEncoder::CreateServer(request_extensions, | |
| 224 &response_extensions_)); | |
| 225 if (!response_extensions_.empty()) { | |
| 226 response_extensions_ = | |
| 227 "Sec-WebSocket-Extensions: " + response_extensions_ + "\r\n"; | |
| 228 } | |
| 229 } | |
| 230 | |
| 231 scoped_ptr<WebSocketEncoder> encoder_; | |
| 232 std::string response_extensions_; | |
| 233 bool closed_; | |
| 234 | |
| 235 DISALLOW_COPY_AND_ASSIGN(WebSocketHybi17); | |
| 236 }; | |
| 237 | |
| 238 } // anonymous namespace | |
| 239 | |
| 240 WebSocket* WebSocket::CreateWebSocket(HttpServer* server, | |
| 241 HttpConnection* connection, | |
| 242 const HttpServerRequestInfo& request, | |
| 243 size_t* pos) { | |
| 244 WebSocket* socket = WebSocketHybi17::Create(server, connection, request, pos); | |
| 245 if (socket) | |
| 246 return socket; | |
| 247 | |
| 248 return WebSocketHixie76::Create(server, connection, request, pos); | |
| 249 } | 80 } |
| 250 | 81 |
| 251 WebSocket::WebSocket(HttpServer* server, HttpConnection* connection) | 82 WebSocket::ParseResult WebSocket::Read(std::string* message) { |
| 252 : server_(server), | 83 HttpConnection::ReadIOBuffer* read_buf = connection_->read_buf(); |
| 253 connection_(connection) { | 84 base::StringPiece frame(read_buf->StartOfBuffer(), read_buf->GetSize()); |
| 85 int bytes_consumed = 0; |
| 86 ParseResult result = encoder_->DecodeFrame(frame, &bytes_consumed, message); |
| 87 if (result == FRAME_OK) |
| 88 read_buf->DidConsume(bytes_consumed); |
| 89 if (result == FRAME_CLOSE) |
| 90 closed_ = true; |
| 91 return result; |
| 254 } | 92 } |
| 255 | 93 |
| 256 WebSocket::~WebSocket() { | 94 void WebSocket::Send(const std::string& message) { |
| 95 if (closed_) |
| 96 return; |
| 97 std::string encoded; |
| 98 encoder_->EncodeFrame(message, 0, &encoded); |
| 99 server_->SendRaw(connection_->id(), encoded); |
| 257 } | 100 } |
| 258 | 101 |
| 259 } // namespace net | 102 } // namespace net |
| OLD | NEW |