Chromium Code Reviews| Index: remoting/host/gnubby_auth_handler_posix.cc |
| diff --git a/remoting/host/gnubby_auth_handler_posix.cc b/remoting/host/gnubby_auth_handler_posix.cc |
| index c797cce46f4eb31854b9c2bbf8cf47d25b61d092..9df548c114e527a68798b761d734167f177ee598 100644 |
| --- a/remoting/host/gnubby_auth_handler_posix.cc |
| +++ b/remoting/host/gnubby_auth_handler_posix.cc |
| @@ -6,8 +6,8 @@ |
| #include <unistd.h> |
| #include <utility> |
| -#include <vector> |
| +#include "base/base64.h" |
| #include "base/bind.h" |
| #include "base/file_util.h" |
| #include "base/json/json_reader.h" |
| @@ -17,7 +17,7 @@ |
| #include "base/values.h" |
| #include "net/socket/unix_domain_socket_posix.h" |
| #include "remoting/base/logging.h" |
| -#include "remoting/host/gnubby_util.h" |
| +#include "remoting/host/gnubby_socket.h" |
| #include "remoting/proto/control.pb.h" |
| #include "remoting/protocol/client_stub.h" |
| @@ -25,15 +25,14 @@ namespace remoting { |
| namespace { |
| -const int kMaxRequestLength = 4096; |
| - |
| +const char kBase64Data[] = "base64Data"; |
| const char kConnectionId[] = "connectionId"; |
| const char kControlMessage[] = "control"; |
| const char kControlOption[] = "option"; |
| const char kDataMessage[] = "data"; |
| +const char kErrorMessage[] = "error"; |
| const char kGnubbyAuthMessage[] = "gnubby-auth"; |
| const char kGnubbyAuthV1[] = "auth-v1"; |
| -const char kJSONMessage[] = "jsonMessage"; |
| const char kMessageType[] = "type"; |
| // The name of the socket to listen for gnubby requests on. |
| @@ -45,9 +44,8 @@ class CompareSocket { |
| public: |
| explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {} |
| - bool operator()(const std::pair<int, net::StreamListenSocket*> element) |
| - const { |
| - return socket_ == element.second; |
| + bool operator()(const std::pair<int, GnubbySocket*> element) const { |
| + return element.second->IsSocket(socket_); |
| } |
| private: |
| @@ -63,25 +61,8 @@ bool MatchUid(uid_t user_id, gid_t) { |
| return allowed; |
| } |
| -// Returns the request data length from the first four data bytes. |
| -int GetRequestLength(const char* data) { |
| - return ((data[0] & 255) << 24) + ((data[1] & 255) << 16) + |
| - ((data[2] & 255) << 8) + (data[3] & 255) + 4; |
| -} |
| - |
| -// Returns true if the request data is complete (has at least as many bytes as |
| -// indicated by the size in the first four bytes plus four for the first bytes). |
| -bool IsRequestComplete(const char* data, int data_len) { |
| - if (data_len < 4) |
| - return false; |
| - return GetRequestLength(data) <= data_len; |
| -} |
| - |
| -// Returns true if the request data size is bigger than the threshold. |
| -bool IsRequestTooLarge(const char* data, int data_len, int max_len) { |
| - if (data_len < 4) |
| - return false; |
| - return GetRequestLength(data) > max_len; |
| +int GetCommandCode(const std::string& data) { |
| + return data.empty() ? -1 : static_cast<int>(data[0]); |
| } |
| } // namespace |
| @@ -129,22 +110,28 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { |
| LOG(ERROR) << "Invalid gnubby-auth control option"; |
| } |
| } else if (type == kDataMessage) { |
| - int connection_id; |
| - std::string json_message; |
| - if (client_message->GetInteger(kConnectionId, &connection_id) && |
| - client_message->GetString(kJSONMessage, &json_message)) { |
| - ActiveSockets::iterator iter = active_sockets_.find(connection_id); |
| - if (iter != active_sockets_.end()) { |
| - HOST_LOG << "Sending gnubby response"; |
| - |
| - std::string response; |
| - GetGnubbyResponseFromJson(json_message, &response); |
| - iter->second->Send(response); |
| + ActiveSockets::iterator iter = GetActiveSocket(client_message); |
| + if (iter != active_sockets_.end()) { |
| + std::string base64_data; |
| + std::string response; |
| + if (client_message->GetString(kBase64Data, &base64_data) && |
| + base::Base64Decode(base64_data, &response)) { |
| + HOST_LOG << "Sending gnubby response: " << GetCommandCode(response); |
| + iter->second->SendResponse(response); |
| } else { |
| - LOG(ERROR) << "Received gnubby-auth data for unknown connection"; |
| + LOG(ERROR) << "Invalid gnubby data"; |
| + SendErrorAndCloseActiveSocket(iter); |
| } |
| } else { |
| - LOG(ERROR) << "Invalid gnubby-auth data message"; |
| + LOG(ERROR) << "Unknown gnubby-auth data connection"; |
| + } |
| + } else if (type == kErrorMessage) { |
| + ActiveSockets::iterator iter = GetActiveSocket(client_message); |
| + if (iter != active_sockets_.end()) { |
| + HOST_LOG << "Sending gnubby error"; |
| + SendErrorAndCloseActiveSocket(iter); |
| + } else { |
| + LOG(ERROR) << "Unknown gnubby-auth error connection"; |
| } |
| } else { |
| LOG(ERROR) << "Unknown gnubby-auth message type: " << type; |
| @@ -152,15 +139,15 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { |
| } |
| } |
| -void GnubbyAuthHandlerPosix::DeliverHostDataMessage(int connection_id, |
| - const std::string& data) |
| - const { |
| +void GnubbyAuthHandlerPosix::DeliverHostDataMessage( |
| + int connection_id, |
| + const std::string& base64_data) const { |
| DCHECK(CalledOnValidThread()); |
| base::DictionaryValue request; |
| request.SetString(kMessageType, kDataMessage); |
| request.SetInteger(kConnectionId, connection_id); |
| - request.SetString(kJSONMessage, data); |
| + request.SetString(kBase64Data, base64_data); |
| std::string request_json; |
| if (!base::JSONWriter::Write(&request, &request_json)) { |
| @@ -182,12 +169,31 @@ bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting( |
| CompareSocket(socket)) != active_sockets_.end(); |
| } |
| +int GnubbyAuthHandlerPosix::GetConnectionIdForTesting( |
| + net::StreamListenSocket* socket) const { |
| + ActiveSockets::const_iterator iter = std::find_if( |
| + active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
| + return iter->first; |
| +} |
| + |
| +GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting( |
| + net::StreamListenSocket* socket) const { |
| + ActiveSockets::const_iterator iter = std::find_if( |
| + active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
| + return iter->second; |
| +} |
| + |
| void GnubbyAuthHandlerPosix::DidAccept( |
| net::StreamListenSocket* server, |
| scoped_ptr<net::StreamListenSocket> socket) { |
| DCHECK(CalledOnValidThread()); |
| - active_sockets_[++last_connection_id_] = socket.release(); |
| + int connection_id = ++last_connection_id_; |
| + active_sockets_[connection_id] = |
| + new GnubbySocket(socket.Pass(), |
| + base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut, |
| + base::Unretained(this), |
| + connection_id)); |
| } |
| void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
| @@ -195,39 +201,20 @@ void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
| int len) { |
| DCHECK(CalledOnValidThread()); |
| - ActiveSockets::iterator socket_iter = std::find_if( |
| + ActiveSockets::iterator iter = std::find_if( |
| active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
| - if (socket_iter != active_sockets_.end()) { |
| - int connection_id = socket_iter->first; |
| - |
| - ActiveRequests::iterator request_iter = |
| - active_requests_.find(connection_id); |
| - if (request_iter != active_requests_.end()) { |
| - std::vector<char>& saved_vector = request_iter->second; |
| - if (IsRequestTooLarge( |
| - saved_vector.data(), saved_vector.size(), kMaxRequestLength)) { |
| - // We can't close a StreamListenSocket; throw away everything but the |
| - // size bytes. |
| - saved_vector.resize(4); |
| - return; |
| - } |
| - saved_vector.insert(saved_vector.end(), data, data + len); |
| - |
| - if (IsRequestComplete(saved_vector.data(), saved_vector.size())) { |
| - ProcessGnubbyRequest( |
| - connection_id, saved_vector.data(), saved_vector.size()); |
| - active_requests_.erase(request_iter); |
| - } |
| - } else if (IsRequestComplete(data, len)) { |
| - ProcessGnubbyRequest(connection_id, data, len); |
| - } else { |
| - if (IsRequestTooLarge(data, len, kMaxRequestLength)) { |
| - // Only save the size bytes. |
| - active_requests_[connection_id] = std::vector<char>(data, data + 4); |
| - } else { |
| - active_requests_[connection_id] = std::vector<char>(data, data + len); |
| - } |
| + if (iter != active_sockets_.end()) { |
| + GnubbySocket* gnubby_socket = iter->second; |
| + gnubby_socket->AddRequestData(data, len); |
| + if (gnubby_socket->IsRequestTooLarge()) { |
| + SendErrorAndCloseActiveSocket(iter); |
| + } else if (gnubby_socket->IsRequestComplete()) { |
| + std::string request_data; |
| + gnubby_socket->GetAndClearRequestData(&request_data); |
| + ProcessGnubbyRequest(iter->first, request_data); |
| } |
| + } else { |
| + LOG(ERROR) << "Received data for unknown connection"; |
| } |
| } |
| @@ -237,8 +224,6 @@ void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) { |
| ActiveSockets::iterator iter = std::find_if( |
| active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
| if (iter != active_sockets_.end()) { |
| - active_requests_.erase(iter->first); |
| - |
| delete iter->second; |
| active_sockets_.erase(iter); |
| } |
| @@ -264,16 +249,37 @@ void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() { |
| } |
| } |
| -void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(int connection_id, |
| - const char* data, |
| - int data_len) { |
| - std::string json; |
| - if (GetJsonFromGnubbyRequest(data, data_len, &json)) { |
| - HOST_LOG << "Received gnubby request"; |
| - DeliverHostDataMessage(connection_id, json); |
| - } else { |
| - LOG(ERROR) << "Could not decode gnubby request"; |
| +void GnubbyAuthHandlerPosix::ProcessGnubbyRequest( |
| + int connection_id, |
| + const std::string& request_data) { |
| + HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data); |
| + std::string base64_request; |
| + base::Base64Encode(request_data, &base64_request); |
| + DeliverHostDataMessage(connection_id, base64_request); |
| +} |
| + |
| +GnubbyAuthHandlerPosix::ActiveSockets::iterator |
| +GnubbyAuthHandlerPosix::GetActiveSocket(base::DictionaryValue* dictionary) { |
|
Sergey Ulanov
2014/03/21 02:09:37
Maybe GetSocketForMessage()? Also rename dictionar
psj
2014/03/21 21:30:45
Done.
|
| + int connection_id; |
| + if (dictionary->GetInteger(kConnectionId, &connection_id)) { |
| + return active_sockets_.find(connection_id); |
| } |
| + return active_sockets_.end(); |
| +} |
| + |
| +void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket( |
| + const ActiveSockets::iterator& iter) { |
| + if (iter != active_sockets_.end()) { |
| + iter->second->SendSshError(); |
| + |
| + delete iter->second; |
| + active_sockets_.erase(iter); |
| + } |
| +} |
| + |
| +void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) { |
| + LOG(ERROR) << "Request timed out"; |
|
Sergey Ulanov
2014/03/21 02:09:37
HOST_LOG please. Also update the message to indica
psj
2014/03/21 21:30:45
Done.
|
| + SendErrorAndCloseActiveSocket(active_sockets_.find(connection_id)); |
| } |
| } // namespace remoting |