Index: remoting/host/gnubby_auth_handler_posix.cc |
diff --git a/remoting/host/gnubby_auth_handler_posix.cc b/remoting/host/gnubby_auth_handler_posix.cc |
index c797cce46f4eb31854b9c2bbf8cf47d25b61d092..9df548c114e527a68798b761d734167f177ee598 100644 |
--- a/remoting/host/gnubby_auth_handler_posix.cc |
+++ b/remoting/host/gnubby_auth_handler_posix.cc |
@@ -6,8 +6,8 @@ |
#include <unistd.h> |
#include <utility> |
-#include <vector> |
+#include "base/base64.h" |
#include "base/bind.h" |
#include "base/file_util.h" |
#include "base/json/json_reader.h" |
@@ -17,7 +17,7 @@ |
#include "base/values.h" |
#include "net/socket/unix_domain_socket_posix.h" |
#include "remoting/base/logging.h" |
-#include "remoting/host/gnubby_util.h" |
+#include "remoting/host/gnubby_socket.h" |
#include "remoting/proto/control.pb.h" |
#include "remoting/protocol/client_stub.h" |
@@ -25,15 +25,14 @@ namespace remoting { |
namespace { |
-const int kMaxRequestLength = 4096; |
- |
+const char kBase64Data[] = "base64Data"; |
const char kConnectionId[] = "connectionId"; |
const char kControlMessage[] = "control"; |
const char kControlOption[] = "option"; |
const char kDataMessage[] = "data"; |
+const char kErrorMessage[] = "error"; |
const char kGnubbyAuthMessage[] = "gnubby-auth"; |
const char kGnubbyAuthV1[] = "auth-v1"; |
-const char kJSONMessage[] = "jsonMessage"; |
const char kMessageType[] = "type"; |
// The name of the socket to listen for gnubby requests on. |
@@ -45,9 +44,8 @@ class CompareSocket { |
public: |
explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {} |
- bool operator()(const std::pair<int, net::StreamListenSocket*> element) |
- const { |
- return socket_ == element.second; |
+ bool operator()(const std::pair<int, GnubbySocket*> element) const { |
+ return element.second->IsSocket(socket_); |
} |
private: |
@@ -63,25 +61,8 @@ bool MatchUid(uid_t user_id, gid_t) { |
return allowed; |
} |
-// Returns the request data length from the first four data bytes. |
-int GetRequestLength(const char* data) { |
- return ((data[0] & 255) << 24) + ((data[1] & 255) << 16) + |
- ((data[2] & 255) << 8) + (data[3] & 255) + 4; |
-} |
- |
-// Returns true if the request data is complete (has at least as many bytes as |
-// indicated by the size in the first four bytes plus four for the first bytes). |
-bool IsRequestComplete(const char* data, int data_len) { |
- if (data_len < 4) |
- return false; |
- return GetRequestLength(data) <= data_len; |
-} |
- |
-// Returns true if the request data size is bigger than the threshold. |
-bool IsRequestTooLarge(const char* data, int data_len, int max_len) { |
- if (data_len < 4) |
- return false; |
- return GetRequestLength(data) > max_len; |
+int GetCommandCode(const std::string& data) { |
+ return data.empty() ? -1 : static_cast<int>(data[0]); |
} |
} // namespace |
@@ -129,22 +110,28 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { |
LOG(ERROR) << "Invalid gnubby-auth control option"; |
} |
} else if (type == kDataMessage) { |
- int connection_id; |
- std::string json_message; |
- if (client_message->GetInteger(kConnectionId, &connection_id) && |
- client_message->GetString(kJSONMessage, &json_message)) { |
- ActiveSockets::iterator iter = active_sockets_.find(connection_id); |
- if (iter != active_sockets_.end()) { |
- HOST_LOG << "Sending gnubby response"; |
- |
- std::string response; |
- GetGnubbyResponseFromJson(json_message, &response); |
- iter->second->Send(response); |
+ ActiveSockets::iterator iter = GetActiveSocket(client_message); |
+ if (iter != active_sockets_.end()) { |
+ std::string base64_data; |
+ std::string response; |
+ if (client_message->GetString(kBase64Data, &base64_data) && |
+ base::Base64Decode(base64_data, &response)) { |
+ HOST_LOG << "Sending gnubby response: " << GetCommandCode(response); |
+ iter->second->SendResponse(response); |
} else { |
- LOG(ERROR) << "Received gnubby-auth data for unknown connection"; |
+ LOG(ERROR) << "Invalid gnubby data"; |
+ SendErrorAndCloseActiveSocket(iter); |
} |
} else { |
- LOG(ERROR) << "Invalid gnubby-auth data message"; |
+ LOG(ERROR) << "Unknown gnubby-auth data connection"; |
+ } |
+ } else if (type == kErrorMessage) { |
+ ActiveSockets::iterator iter = GetActiveSocket(client_message); |
+ if (iter != active_sockets_.end()) { |
+ HOST_LOG << "Sending gnubby error"; |
+ SendErrorAndCloseActiveSocket(iter); |
+ } else { |
+ LOG(ERROR) << "Unknown gnubby-auth error connection"; |
} |
} else { |
LOG(ERROR) << "Unknown gnubby-auth message type: " << type; |
@@ -152,15 +139,15 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) { |
} |
} |
-void GnubbyAuthHandlerPosix::DeliverHostDataMessage(int connection_id, |
- const std::string& data) |
- const { |
+void GnubbyAuthHandlerPosix::DeliverHostDataMessage( |
+ int connection_id, |
+ const std::string& base64_data) const { |
DCHECK(CalledOnValidThread()); |
base::DictionaryValue request; |
request.SetString(kMessageType, kDataMessage); |
request.SetInteger(kConnectionId, connection_id); |
- request.SetString(kJSONMessage, data); |
+ request.SetString(kBase64Data, base64_data); |
std::string request_json; |
if (!base::JSONWriter::Write(&request, &request_json)) { |
@@ -182,12 +169,31 @@ bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting( |
CompareSocket(socket)) != active_sockets_.end(); |
} |
+int GnubbyAuthHandlerPosix::GetConnectionIdForTesting( |
+ net::StreamListenSocket* socket) const { |
+ ActiveSockets::const_iterator iter = std::find_if( |
+ active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
+ return iter->first; |
+} |
+ |
+GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting( |
+ net::StreamListenSocket* socket) const { |
+ ActiveSockets::const_iterator iter = std::find_if( |
+ active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
+ return iter->second; |
+} |
+ |
void GnubbyAuthHandlerPosix::DidAccept( |
net::StreamListenSocket* server, |
scoped_ptr<net::StreamListenSocket> socket) { |
DCHECK(CalledOnValidThread()); |
- active_sockets_[++last_connection_id_] = socket.release(); |
+ int connection_id = ++last_connection_id_; |
+ active_sockets_[connection_id] = |
+ new GnubbySocket(socket.Pass(), |
+ base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut, |
+ base::Unretained(this), |
+ connection_id)); |
} |
void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
@@ -195,39 +201,20 @@ void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket, |
int len) { |
DCHECK(CalledOnValidThread()); |
- ActiveSockets::iterator socket_iter = std::find_if( |
+ ActiveSockets::iterator iter = std::find_if( |
active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
- if (socket_iter != active_sockets_.end()) { |
- int connection_id = socket_iter->first; |
- |
- ActiveRequests::iterator request_iter = |
- active_requests_.find(connection_id); |
- if (request_iter != active_requests_.end()) { |
- std::vector<char>& saved_vector = request_iter->second; |
- if (IsRequestTooLarge( |
- saved_vector.data(), saved_vector.size(), kMaxRequestLength)) { |
- // We can't close a StreamListenSocket; throw away everything but the |
- // size bytes. |
- saved_vector.resize(4); |
- return; |
- } |
- saved_vector.insert(saved_vector.end(), data, data + len); |
- |
- if (IsRequestComplete(saved_vector.data(), saved_vector.size())) { |
- ProcessGnubbyRequest( |
- connection_id, saved_vector.data(), saved_vector.size()); |
- active_requests_.erase(request_iter); |
- } |
- } else if (IsRequestComplete(data, len)) { |
- ProcessGnubbyRequest(connection_id, data, len); |
- } else { |
- if (IsRequestTooLarge(data, len, kMaxRequestLength)) { |
- // Only save the size bytes. |
- active_requests_[connection_id] = std::vector<char>(data, data + 4); |
- } else { |
- active_requests_[connection_id] = std::vector<char>(data, data + len); |
- } |
+ if (iter != active_sockets_.end()) { |
+ GnubbySocket* gnubby_socket = iter->second; |
+ gnubby_socket->AddRequestData(data, len); |
+ if (gnubby_socket->IsRequestTooLarge()) { |
+ SendErrorAndCloseActiveSocket(iter); |
+ } else if (gnubby_socket->IsRequestComplete()) { |
+ std::string request_data; |
+ gnubby_socket->GetAndClearRequestData(&request_data); |
+ ProcessGnubbyRequest(iter->first, request_data); |
} |
+ } else { |
+ LOG(ERROR) << "Received data for unknown connection"; |
} |
} |
@@ -237,8 +224,6 @@ void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) { |
ActiveSockets::iterator iter = std::find_if( |
active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket)); |
if (iter != active_sockets_.end()) { |
- active_requests_.erase(iter->first); |
- |
delete iter->second; |
active_sockets_.erase(iter); |
} |
@@ -264,16 +249,37 @@ void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() { |
} |
} |
-void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(int connection_id, |
- const char* data, |
- int data_len) { |
- std::string json; |
- if (GetJsonFromGnubbyRequest(data, data_len, &json)) { |
- HOST_LOG << "Received gnubby request"; |
- DeliverHostDataMessage(connection_id, json); |
- } else { |
- LOG(ERROR) << "Could not decode gnubby request"; |
+void GnubbyAuthHandlerPosix::ProcessGnubbyRequest( |
+ int connection_id, |
+ const std::string& request_data) { |
+ HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data); |
+ std::string base64_request; |
+ base::Base64Encode(request_data, &base64_request); |
+ DeliverHostDataMessage(connection_id, base64_request); |
+} |
+ |
+GnubbyAuthHandlerPosix::ActiveSockets::iterator |
+GnubbyAuthHandlerPosix::GetActiveSocket(base::DictionaryValue* dictionary) { |
Sergey Ulanov
2014/03/21 02:09:37
Maybe GetSocketForMessage()? Also rename dictionar
psj
2014/03/21 21:30:45
Done.
|
+ int connection_id; |
+ if (dictionary->GetInteger(kConnectionId, &connection_id)) { |
+ return active_sockets_.find(connection_id); |
} |
+ return active_sockets_.end(); |
+} |
+ |
+void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket( |
+ const ActiveSockets::iterator& iter) { |
+ if (iter != active_sockets_.end()) { |
+ iter->second->SendSshError(); |
+ |
+ delete iter->second; |
+ active_sockets_.erase(iter); |
+ } |
+} |
+ |
+void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) { |
+ LOG(ERROR) << "Request timed out"; |
Sergey Ulanov
2014/03/21 02:09:37
HOST_LOG please. Also update the message to indica
psj
2014/03/21 21:30:45
Done.
|
+ SendErrorAndCloseActiveSocket(active_sockets_.find(connection_id)); |
} |
} // namespace remoting |