| Index: remoting/host/gnubby_auth_handler_posix.cc
|
| diff --git a/remoting/host/gnubby_auth_handler_posix.cc b/remoting/host/gnubby_auth_handler_posix.cc
|
| index c797cce46f4eb31854b9c2bbf8cf47d25b61d092..77938e078b70c920a03c21e1ab3532d4cfe020fd 100644
|
| --- a/remoting/host/gnubby_auth_handler_posix.cc
|
| +++ b/remoting/host/gnubby_auth_handler_posix.cc
|
| @@ -6,7 +6,6 @@
|
|
|
| #include <unistd.h>
|
| #include <utility>
|
| -#include <vector>
|
|
|
| #include "base/bind.h"
|
| #include "base/file_util.h"
|
| @@ -17,7 +16,7 @@
|
| #include "base/values.h"
|
| #include "net/socket/unix_domain_socket_posix.h"
|
| #include "remoting/base/logging.h"
|
| -#include "remoting/host/gnubby_util.h"
|
| +#include "remoting/host/gnubby_socket.h"
|
| #include "remoting/proto/control.pb.h"
|
| #include "remoting/protocol/client_stub.h"
|
|
|
| @@ -25,15 +24,14 @@ namespace remoting {
|
|
|
| namespace {
|
|
|
| -const int kMaxRequestLength = 4096;
|
| -
|
| const char kConnectionId[] = "connectionId";
|
| const char kControlMessage[] = "control";
|
| const char kControlOption[] = "option";
|
| const char kDataMessage[] = "data";
|
| +const char kDataPayload[] = "data";
|
| +const char kErrorMessage[] = "error";
|
| const char kGnubbyAuthMessage[] = "gnubby-auth";
|
| const char kGnubbyAuthV1[] = "auth-v1";
|
| -const char kJSONMessage[] = "jsonMessage";
|
| const char kMessageType[] = "type";
|
|
|
| // The name of the socket to listen for gnubby requests on.
|
| @@ -45,9 +43,8 @@ class CompareSocket {
|
| public:
|
| explicit CompareSocket(net::StreamListenSocket* socket) : socket_(socket) {}
|
|
|
| - bool operator()(const std::pair<int, net::StreamListenSocket*> element)
|
| - const {
|
| - return socket_ == element.second;
|
| + bool operator()(const std::pair<int, GnubbySocket*> element) const {
|
| + return element.second->IsSocket(socket_);
|
| }
|
|
|
| private:
|
| @@ -63,25 +60,28 @@ bool MatchUid(uid_t user_id, gid_t) {
|
| return allowed;
|
| }
|
|
|
| -// Returns the request data length from the first four data bytes.
|
| -int GetRequestLength(const char* data) {
|
| - return ((data[0] & 255) << 24) + ((data[1] & 255) << 16) +
|
| - ((data[2] & 255) << 8) + (data[3] & 255) + 4;
|
| -}
|
| -
|
| -// Returns true if the request data is complete (has at least as many bytes as
|
| -// indicated by the size in the first four bytes plus four for the first bytes).
|
| -bool IsRequestComplete(const char* data, int data_len) {
|
| - if (data_len < 4)
|
| - return false;
|
| - return GetRequestLength(data) <= data_len;
|
| +// Returns the command code (the first byte of the data) if it exists, or -1 if
|
| +// the data is empty.
|
| +unsigned int GetCommandCode(const std::string& data) {
|
| + return data.empty() ? -1 : static_cast<unsigned int>(data[0]);
|
| }
|
|
|
| -// Returns true if the request data size is bigger than the threshold.
|
| -bool IsRequestTooLarge(const char* data, int data_len, int max_len) {
|
| - if (data_len < 4)
|
| - return false;
|
| - return GetRequestLength(data) > max_len;
|
| +// Creates a string of byte data from a ListValue of numbers. Returns true if
|
| +// all of the list elements are numbers.
|
| +bool ConvertListValueToString(base::ListValue* bytes, std::string* out) {
|
| + out->clear();
|
| +
|
| + unsigned int byte_count = bytes->GetSize();
|
| + if (byte_count != 0) {
|
| + out->reserve(byte_count);
|
| + for (unsigned int i = 0; i < byte_count; i++) {
|
| + int value;
|
| + if (!bytes->GetInteger(i, &value))
|
| + return false;
|
| + out->push_back(static_cast<char>(value));
|
| + }
|
| + }
|
| + return true;
|
| }
|
|
|
| } // namespace
|
| @@ -129,22 +129,28 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
|
| LOG(ERROR) << "Invalid gnubby-auth control option";
|
| }
|
| } else if (type == kDataMessage) {
|
| - int connection_id;
|
| - std::string json_message;
|
| - if (client_message->GetInteger(kConnectionId, &connection_id) &&
|
| - client_message->GetString(kJSONMessage, &json_message)) {
|
| - ActiveSockets::iterator iter = active_sockets_.find(connection_id);
|
| - if (iter != active_sockets_.end()) {
|
| - HOST_LOG << "Sending gnubby response";
|
| -
|
| - std::string response;
|
| - GetGnubbyResponseFromJson(json_message, &response);
|
| - iter->second->Send(response);
|
| + ActiveSockets::iterator iter = GetSocketForMessage(client_message);
|
| + if (iter != active_sockets_.end()) {
|
| + base::ListValue* bytes;
|
| + std::string response;
|
| + if (client_message->GetList(kDataPayload, &bytes) &&
|
| + ConvertListValueToString(bytes, &response)) {
|
| + HOST_LOG << "Sending gnubby response: " << GetCommandCode(response);
|
| + iter->second->SendResponse(response);
|
| } else {
|
| - LOG(ERROR) << "Received gnubby-auth data for unknown connection";
|
| + LOG(ERROR) << "Invalid gnubby data";
|
| + SendErrorAndCloseActiveSocket(iter);
|
| }
|
| } else {
|
| - LOG(ERROR) << "Invalid gnubby-auth data message";
|
| + LOG(ERROR) << "Unknown gnubby-auth data connection";
|
| + }
|
| + } else if (type == kErrorMessage) {
|
| + ActiveSockets::iterator iter = GetSocketForMessage(client_message);
|
| + if (iter != active_sockets_.end()) {
|
| + HOST_LOG << "Sending gnubby error";
|
| + SendErrorAndCloseActiveSocket(iter);
|
| + } else {
|
| + LOG(ERROR) << "Unknown gnubby-auth error connection";
|
| }
|
| } else {
|
| LOG(ERROR) << "Unknown gnubby-auth message type: " << type;
|
| @@ -152,15 +158,20 @@ void GnubbyAuthHandlerPosix::DeliverClientMessage(const std::string& message) {
|
| }
|
| }
|
|
|
| -void GnubbyAuthHandlerPosix::DeliverHostDataMessage(int connection_id,
|
| - const std::string& data)
|
| - const {
|
| +void GnubbyAuthHandlerPosix::DeliverHostDataMessage(
|
| + int connection_id,
|
| + const std::string& data) const {
|
| DCHECK(CalledOnValidThread());
|
|
|
| base::DictionaryValue request;
|
| request.SetString(kMessageType, kDataMessage);
|
| request.SetInteger(kConnectionId, connection_id);
|
| - request.SetString(kJSONMessage, data);
|
| +
|
| + base::ListValue* bytes = new base::ListValue();
|
| + for (std::string::const_iterator i = data.begin(); i != data.end(); ++i) {
|
| + bytes->AppendInteger(static_cast<unsigned char>(*i));
|
| + }
|
| + request.Set(kDataPayload, bytes);
|
|
|
| std::string request_json;
|
| if (!base::JSONWriter::Write(&request, &request_json)) {
|
| @@ -182,12 +193,31 @@ bool GnubbyAuthHandlerPosix::HasActiveSocketForTesting(
|
| CompareSocket(socket)) != active_sockets_.end();
|
| }
|
|
|
| +int GnubbyAuthHandlerPosix::GetConnectionIdForTesting(
|
| + net::StreamListenSocket* socket) const {
|
| + ActiveSockets::const_iterator iter = std::find_if(
|
| + active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
|
| + return iter->first;
|
| +}
|
| +
|
| +GnubbySocket* GnubbyAuthHandlerPosix::GetGnubbySocketForTesting(
|
| + net::StreamListenSocket* socket) const {
|
| + ActiveSockets::const_iterator iter = std::find_if(
|
| + active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
|
| + return iter->second;
|
| +}
|
| +
|
| void GnubbyAuthHandlerPosix::DidAccept(
|
| net::StreamListenSocket* server,
|
| scoped_ptr<net::StreamListenSocket> socket) {
|
| DCHECK(CalledOnValidThread());
|
|
|
| - active_sockets_[++last_connection_id_] = socket.release();
|
| + int connection_id = ++last_connection_id_;
|
| + active_sockets_[connection_id] =
|
| + new GnubbySocket(socket.Pass(),
|
| + base::Bind(&GnubbyAuthHandlerPosix::RequestTimedOut,
|
| + base::Unretained(this),
|
| + connection_id));
|
| }
|
|
|
| void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
|
| @@ -195,39 +225,20 @@ void GnubbyAuthHandlerPosix::DidRead(net::StreamListenSocket* socket,
|
| int len) {
|
| DCHECK(CalledOnValidThread());
|
|
|
| - ActiveSockets::iterator socket_iter = std::find_if(
|
| + ActiveSockets::iterator iter = std::find_if(
|
| active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
|
| - if (socket_iter != active_sockets_.end()) {
|
| - int connection_id = socket_iter->first;
|
| -
|
| - ActiveRequests::iterator request_iter =
|
| - active_requests_.find(connection_id);
|
| - if (request_iter != active_requests_.end()) {
|
| - std::vector<char>& saved_vector = request_iter->second;
|
| - if (IsRequestTooLarge(
|
| - saved_vector.data(), saved_vector.size(), kMaxRequestLength)) {
|
| - // We can't close a StreamListenSocket; throw away everything but the
|
| - // size bytes.
|
| - saved_vector.resize(4);
|
| - return;
|
| - }
|
| - saved_vector.insert(saved_vector.end(), data, data + len);
|
| -
|
| - if (IsRequestComplete(saved_vector.data(), saved_vector.size())) {
|
| - ProcessGnubbyRequest(
|
| - connection_id, saved_vector.data(), saved_vector.size());
|
| - active_requests_.erase(request_iter);
|
| - }
|
| - } else if (IsRequestComplete(data, len)) {
|
| - ProcessGnubbyRequest(connection_id, data, len);
|
| - } else {
|
| - if (IsRequestTooLarge(data, len, kMaxRequestLength)) {
|
| - // Only save the size bytes.
|
| - active_requests_[connection_id] = std::vector<char>(data, data + 4);
|
| - } else {
|
| - active_requests_[connection_id] = std::vector<char>(data, data + len);
|
| - }
|
| + if (iter != active_sockets_.end()) {
|
| + GnubbySocket* gnubby_socket = iter->second;
|
| + gnubby_socket->AddRequestData(data, len);
|
| + if (gnubby_socket->IsRequestTooLarge()) {
|
| + SendErrorAndCloseActiveSocket(iter);
|
| + } else if (gnubby_socket->IsRequestComplete()) {
|
| + std::string request_data;
|
| + gnubby_socket->GetAndClearRequestData(&request_data);
|
| + ProcessGnubbyRequest(iter->first, request_data);
|
| }
|
| + } else {
|
| + LOG(ERROR) << "Received data for unknown connection";
|
| }
|
| }
|
|
|
| @@ -237,8 +248,6 @@ void GnubbyAuthHandlerPosix::DidClose(net::StreamListenSocket* socket) {
|
| ActiveSockets::iterator iter = std::find_if(
|
| active_sockets_.begin(), active_sockets_.end(), CompareSocket(socket));
|
| if (iter != active_sockets_.end()) {
|
| - active_requests_.erase(iter->first);
|
| -
|
| delete iter->second;
|
| active_sockets_.erase(iter);
|
| }
|
| @@ -264,16 +273,35 @@ void GnubbyAuthHandlerPosix::CreateAuthorizationSocket() {
|
| }
|
| }
|
|
|
| -void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(int connection_id,
|
| - const char* data,
|
| - int data_len) {
|
| - std::string json;
|
| - if (GetJsonFromGnubbyRequest(data, data_len, &json)) {
|
| - HOST_LOG << "Received gnubby request";
|
| - DeliverHostDataMessage(connection_id, json);
|
| - } else {
|
| - LOG(ERROR) << "Could not decode gnubby request";
|
| +void GnubbyAuthHandlerPosix::ProcessGnubbyRequest(
|
| + int connection_id,
|
| + const std::string& request_data) {
|
| + HOST_LOG << "Received gnubby request: " << GetCommandCode(request_data);
|
| + DeliverHostDataMessage(connection_id, request_data);
|
| +}
|
| +
|
| +GnubbyAuthHandlerPosix::ActiveSockets::iterator
|
| +GnubbyAuthHandlerPosix::GetSocketForMessage(base::DictionaryValue* message) {
|
| + int connection_id;
|
| + if (message->GetInteger(kConnectionId, &connection_id)) {
|
| + return active_sockets_.find(connection_id);
|
| }
|
| + return active_sockets_.end();
|
| +}
|
| +
|
| +void GnubbyAuthHandlerPosix::SendErrorAndCloseActiveSocket(
|
| + const ActiveSockets::iterator& iter) {
|
| + iter->second->SendSshError();
|
| +
|
| + delete iter->second;
|
| + active_sockets_.erase(iter);
|
| +}
|
| +
|
| +void GnubbyAuthHandlerPosix::RequestTimedOut(int connection_id) {
|
| + HOST_LOG << "Gnubby request timed out";
|
| + ActiveSockets::iterator iter = active_sockets_.find(connection_id);
|
| + if (iter != active_sockets_.end())
|
| + SendErrorAndCloseActiveSocket(iter);
|
| }
|
|
|
| } // namespace remoting
|
|
|