Index: remoting/host/gnubby_auth_handler_posix.cc |
diff --git a/remoting/host/gnubby_auth_handler_posix.cc b/remoting/host/gnubby_auth_handler_posix.cc |
index c797cce46f4eb31854b9c2bbf8cf47d25b61d092..77938e078b70c920a03c21e1ab3532d4cfe020fd 100644 |
--- a/remoting/host/gnubby_auth_handler_posix.cc |
+++ b/remoting/host/gnubby_auth_handler_posix.cc |
@@ -6,7 +6,6 @@ |
#include <unistd.h> |
#include <utility> |
-#include <vector> |
#include "base/bind.h" |
#include "base/file_util.h" |
@@ -17,7 +16,7 @@ |
#include "base/values.h" |
#include "net/socket/unix_domain_socket_posix.h" |
#include "remoting/base/logging.h" |
-#include "remoting/host/gnubby_util.h" |
+#include "remoting/host/gnubby_socket.h" |
#include "remoting/proto/control.pb.h" |
#include "remoting/protocol/client_stub.h" |
@@ -25,15 +24,14 @@ namespace remoting { |
namespace { |
-const int kMaxRequestLength = 4096; |
- |
const char kConnectionId[] = "connectionId"; |
const char kControlMessage[] = "control"; |
const char kControlOption[] = "option"; |
const char kDataMessage[] = "data"; |
+const char kDataPayload[] = "data"; |
+const char kErrorMessage[] = "error"; |
const char kGnubbyAuthMessage[] = "gnubby-auth"; |
const char kGnubbyAuthV1[] = "auth-v1"; |
-const char kJSONMessage[] = "jsonMessage"; |
const char kMessageType[] = "type"; |
// The name of the socket to listen for gnubby requests on. |
@@ -45,9 +43,8 @@ class CompareSocket { |
public: |
explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {} |
- bool operator()(const std::pair<int, net::StreamListenSocket*> element) |
- const { |
- return socket_ == element.second; |
+ bool operator()(const std::pair<int, GnubbySocket*> element) const { |
+ return element.second->IsSocket(socket_); |
} |
private: |
@@ -63,25 +60,28 @@ bool MatchUid(uid_t user_id, gid_t) { |
return allowed; |
} |
-// Returns the request data length from the first four data bytes. |
-int GetRequestLength(const char* data) { |
- return ((data[0] & 255) << 24) + ((data[1] & 255) << 16) + |
- ((data[2] & 255) << 8) + (data[3] & 255) + 4; |
-} |
- |
-// Returns true if the request data is complete (has at least as many bytes as |
-// indicated by the size in the first four bytes plus four for the first bytes). |
-bool IsRequestComplete(const char* data, int data_len) { |
- if (data_len < 4) |
- return false; |
- return GetRequestLength(data) <= data_len; |
+// Returns the command code (the first byte of the data) if it exists, or -1 if |
+// the data is empty. |
+unsigned int GetCommandCode(const std::string& data) { |
+ return data.empty() ? -1 : static_cast<unsigned int>(data[0]); |
} |
-// Returns true if the request data size is bigger than the threshold. |
-bool IsRequestTooLarge(const char* data, int data_len, int max_len) { |
- if (data_len < 4) |
- return false; |
- return GetRequestLength(data) > max_len; |
+// Creates a string of byte data from a ListValue of numbers. Returns true if |
+// all of the list elements are numbers. |
+bool ConvertListValueToString(base::ListValue* bytes, std::string* out) { |
+ out->clear(); |
+ |
+ unsigned int byte_count = bytes->GetSize(); |
+ if (byte_count != 0) { |
+ out->reserve(byte_count); |
+ for (unsigned int i = 0; i < byte_count; i++) { |
+ int value; |
+ if (!bytes->GetInteger(i, &value)) |
+ return false; |
+ out->push_back(static_cast<char>(value)); |
+ } |
+ } |
+ return true; |
} |
} // namespace |
@@ -129,22 +129,28 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { |
LOG(ERROR) << "Invalid gnubby-auth control option"; |
} |
} else if (type == kDataMessage) { |
- int connection_id; |
- std::string json_message; |
- if (client_message->GetInteger(kConnectionId, &connection_id) && |
- client_message->GetString(kJSONMessage, &json_message)) { |
- ActiveSockets::iterator iter = active_sockets_.find(connection_id); |
- if (iter != active_sockets_.end()) { |
- HOST_LOG << "Sending gnubby response"; |
- |
- std::string response; |
- GetGnubbyResponseFromJson(json_message, &response); |
- iter->second->Send(response); |
+ ActiveSockets::iterator iter = GetSocketForMessage(client_message); |
+ if (iter != active_sockets_.end()) { |
+ base::ListValue* bytes; |
+ std::string response; |
+ if (client_message->GetList(kDataPayload, &bytes) && |
+ ConvertListValueToString(bytes, &response)) { |
+ HOST_LOG << "Sending gnubby response: " << GetCommandCode(response); |
+ iter->second->SendResponse(response); |
} else { |
- LOG(ERROR) << "Received gnubby-auth data for unknown connection"; |
+ LOG(ERROR) << "Invalid gnubby data"; |
+ SendErrorAndCloseActiveSocket(iter); |
} |
} else { |
- LOG(ERROR) << "Invalid gnubby-auth data message"; |
+ LOG(ERROR) << "Unknown gnubby-auth data connection"; |
+ } |
+ } else if (type == kErrorMessage) { |
+ ActiveSockets::iterator iter = GetSocketForMessage(client_message); |
+ if (iter != active_sockets_.end()) { |
+ HOST_LOG << "Sending gnubby error"; |
+ SendErrorAndCloseActiveSocket(iter); |
+ } else { |
+ LOG(ERROR) << "Unknown gnubby-auth error connection"; |
} |
} else { |
LOG(ERROR) << "Unknown gnubby-auth message type: " << type; |
@@ -152,15 +158,20 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { |
} |
} |
-void GnubbyAuthHandlerPosix::DeliverHostDataMessage(int connection_id, |
- const std::string& data) |
- const { |
+void GnubbyAuthHandlerPosix::DeliverHostDataMessage( |
+ int connection_id, |
+ const std::string& data) const { |
DCHECK(CalledOnValidThread()); |
base::DictionaryValue request; |
request.SetString(kMessageType, kDataMessage); |
request.SetInteger(kConnectionId, connection_id); |
- request.SetString(kJSONMessage, data); |
+ |
+ base::ListValue* bytes = new base::ListValue(); |
+ for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) { |
+ bytes->AppendInteger(static_cast<unsigned char>(*i)); |
+ } |
+ request.Set(kDataPayload, bytes); |
std::string request_json; |
if (!base::JSONWriter::Write(&request, &request_json)) { |
@@ -182,12 +193,31 @@ bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting( |
CompareSocket(socket)) != active_sockets_.end(); |
} |
+int GnubbyAuthHandlerPosix::GetConnectionIdForTesting( |
+ net::StreamListenSocket* socket) const { |
+ ActiveSockets::const_iterator iter = std::find_if( |
+ active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
+ return iter->first; |
+} |
+ |
+GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting( |
+ net::StreamListenSocket* socket) const { |
+ ActiveSockets::const_iterator iter = std::find_if( |
+ active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
+ return iter->second; |
+} |
+ |
void GnubbyAuthHandlerPosix::DidAccept( |
net::StreamListenSocket* server, |
scoped_ptr<net::StreamListenSocket> socket) { |
DCHECK(CalledOnValidThread()); |
- active_sockets_[++last_connection_id_] = socket.release(); |
+ int connection_id = ++last_connection_id_; |
+ active_sockets_[connection_id] = |
+ new GnubbySocket(socket.Pass(), |
+ base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut, |
+ base::Unretained(this), |
+ connection_id)); |
} |
void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
@@ -195,39 +225,20 @@ void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
int len) { |
DCHECK(CalledOnValidThread()); |
- ActiveSockets::iterator socket_iter = std::find_if( |
+ ActiveSockets::iterator iter = std::find_if( |
active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
- if (socket_iter != active_sockets_.end()) { |
- int connection_id = socket_iter->first; |
- |
- ActiveRequests::iterator request_iter = |
- active_requests_.find(connection_id); |
- if (request_iter != active_requests_.end()) { |
- std::vector<char>& saved_vector = request_iter->second; |
- if (IsRequestTooLarge( |
- saved_vector.data(), saved_vector.size(), kMaxRequestLength)) { |
- // We can't close a StreamListenSocket; throw away everything but the |
- // size bytes. |
- saved_vector.resize(4); |
- return; |
- } |
- saved_vector.insert(saved_vector.end(), data, data + len); |
- |
- if (IsRequestComplete(saved_vector.data(), saved_vector.size())) { |
- ProcessGnubbyRequest( |
- connection_id, saved_vector.data(), saved_vector.size()); |
- active_requests_.erase(request_iter); |
- } |
- } else if (IsRequestComplete(data, len)) { |
- ProcessGnubbyRequest(connection_id, data, len); |
- } else { |
- if (IsRequestTooLarge(data, len, kMaxRequestLength)) { |
- // Only save the size bytes. |
- active_requests_[connection_id] = std::vector<char>(data, data + 4); |
- } else { |
- active_requests_[connection_id] = std::vector<char>(data, data + len); |
- } |
+ if (iter != active_sockets_.end()) { |
+ GnubbySocket* gnubby_socket = iter->second; |
+ gnubby_socket->AddRequestData(data, len); |
+ if (gnubby_socket->IsRequestTooLarge()) { |
+ SendErrorAndCloseActiveSocket(iter); |
+ } else if (gnubby_socket->IsRequestComplete()) { |
+ std::string request_data; |
+ gnubby_socket->GetAndClearRequestData(&request_data); |
+ ProcessGnubbyRequest(iter->first, request_data); |
} |
+ } else { |
+ LOG(ERROR) << "Received data for unknown connection"; |
} |
} |
@@ -237,8 +248,6 @@ void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) { |
ActiveSockets::iterator iter = std::find_if( |
active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
if (iter != active_sockets_.end()) { |
- active_requests_.erase(iter->first); |
- |
delete iter->second; |
active_sockets_.erase(iter); |
} |
@@ -264,16 +273,35 @@ void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() { |
} |
} |
-void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(int connection_id, |
- const char* data, |
- int data_len) { |
- std::string json; |
- if (GetJsonFromGnubbyRequest(data, data_len, &json)) { |
- HOST_LOG << "Received gnubby request"; |
- DeliverHostDataMessage(connection_id, json); |
- } else { |
- LOG(ERROR) << "Could not decode gnubby request"; |
+void GnubbyAuthHandlerPosix::ProcessGnubbyRequest( |
+ int connection_id, |
+ const std::string& request_data) { |
+ HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data); |
+ DeliverHostDataMessage(connection_id, request_data); |
+} |
+ |
+GnubbyAuthHandlerPosix::ActiveSockets::iterator |
+GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) { |
+ int connection_id; |
+ if (message->GetInteger(kConnectionId, &connection_id)) { |
+ return active_sockets_.find(connection_id); |
} |
+ return active_sockets_.end(); |
+} |
+ |
+void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket( |
+ const ActiveSockets::iterator& iter) { |
+ iter->second->SendSshError(); |
+ |
+ delete iter->second; |
+ active_sockets_.erase(iter); |
+} |
+ |
+void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) { |
+ HOST_LOG << "Gnubby request timed out"; |
+ ActiveSockets::iterator iter = active_sockets_.find(connection_id); |
+ if (iter != active_sockets_.end()) |
+ SendErrorAndCloseActiveSocket(iter); |
} |
} // namespace remoting |