Index: remoting/host/security_key/gnubby_auth_handler_win.cc |
diff --git a/remoting/host/security_key/gnubby_auth_handler_win.cc b/remoting/host/security_key/gnubby_auth_handler_win.cc |
index fa7a0a0805022951ccae83ebe604128075c5ea4b..98570e14d9fbec8cd5b3a40d32cdb2c951866b10 100644 |
--- a/remoting/host/security_key/gnubby_auth_handler_win.cc |
+++ b/remoting/host/security_key/gnubby_auth_handler_win.cc |
@@ -9,10 +9,13 @@ |
#include <string> |
#include "base/bind.h" |
+#include "base/location.h" |
#include "base/logging.h" |
+#include "base/memory/weak_ptr.h" |
#include "base/strings/stringprintf.h" |
#include "base/strings/utf_string_conversions.h" |
#include "base/threading/thread_checker.h" |
+#include "base/threading/thread_task_runner_handle.h" |
#include "base/time/time.h" |
#include "base/timer/timer.h" |
#include "base/win/win_util.h" |
@@ -69,6 +72,7 @@ class GnubbyAuthHandlerWin : public GnubbyAuthHandler, public IPC::Listener { |
const std::string& response) override; |
void SendErrorAndCloseConnection(int gnubby_connection_id) override; |
void SetSendMessageCallback(const SendMessageCallback& callback) override; |
+ void SetSessionIdCallback(const SessionIdCallback& callback) override; |
size_t GetActiveConnectionCountForTest() const override; |
void SetRequestTimeoutForTest(base::TimeDelta timeout) override; |
@@ -100,6 +104,9 @@ class GnubbyAuthHandlerWin : public GnubbyAuthHandler, public IPC::Listener { |
// Sends a gnubby extension messages to the remote client when called. |
SendMessageCallback send_message_callback_; |
+ // Used to retrieve the id of the remoted session. |
+ SessionIdCallback session_id_callback_; |
+ |
// Tracks the IPC channel created for each security key forwarding session. |
ActiveChannels active_channels_; |
@@ -117,19 +124,24 @@ class GnubbyAuthHandlerWin : public GnubbyAuthHandler, public IPC::Listener { |
// Ensures GnubbyAuthHandlerWin methods are called on the same thread. |
base::ThreadChecker thread_checker_; |
+ base::WeakPtrFactory<GnubbyAuthHandlerWin> weak_factory_; |
+ |
DISALLOW_COPY_AND_ASSIGN(GnubbyAuthHandlerWin); |
}; |
std::unique_ptr<GnubbyAuthHandler> GnubbyAuthHandler::Create( |
- const SendMessageCallback& callback) { |
+ const SendMessageCallback& send_message_callback, |
+ const SessionIdCallback& session_id_callback) { |
std::unique_ptr<GnubbyAuthHandler> auth_handler(new GnubbyAuthHandlerWin()); |
- auth_handler->SetSendMessageCallback(callback); |
+ auth_handler->SetSendMessageCallback(send_message_callback); |
+ auth_handler->SetSessionIdCallback(session_id_callback); |
return auth_handler; |
} |
GnubbyAuthHandlerWin::GnubbyAuthHandlerWin() |
: disconnect_timeout_( |
- base::TimeDelta::FromSeconds(kInitialRequestTimeoutSeconds)) {} |
+ base::TimeDelta::FromSeconds(kInitialRequestTimeoutSeconds)), |
+ weak_factory_(this) {} |
GnubbyAuthHandlerWin::~GnubbyAuthHandlerWin() {} |
@@ -173,6 +185,12 @@ void GnubbyAuthHandlerWin::SetSendMessageCallback( |
send_message_callback_ = callback; |
} |
+void GnubbyAuthHandlerWin::SetSessionIdCallback( |
+ const SessionIdCallback& callback) { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ session_id_callback_ = callback; |
+} |
+ |
size_t GnubbyAuthHandlerWin::GetActiveConnectionCountForTest() const { |
return active_channels_.size(); |
} |
@@ -239,14 +257,30 @@ void GnubbyAuthHandlerWin::OnChannelConnected(int32_t peer_pid) { |
base::Bind(&GnubbyAuthHandlerWin::OnChannelError, |
base::Unretained(this))); |
- // TODO(joedow): Use |peer_pid| to determine the originating session |
- // using ProcessIdToSessionId() and verify it is the one we created. |
- // Tracked via crbug.com/591746 |
+ // Verify the IPC connection attempt originated from the session we are |
+ // currently remoting. We don't want to service requests from arbitrary |
+ // Windows sessions. |
+ DWORD peer_session_id; |
+ if (!ProcessIdToSessionId(peer_pid, &peer_session_id)) { |
+ PLOG(ERROR) << "ProcessIdToSessionId() failed"; |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(&GnubbyAuthHandlerWin::OnChannelError, |
+ weak_factory_.GetWeakPtr())); |
+ return; |
+ } |
+ if (peer_session_id != session_id_callback_.Run()) { |
+ LOG(INFO) << "Ignoring connection attempt from outside remoted session."; |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(&GnubbyAuthHandlerWin::OnChannelError, |
+ weak_factory_.GetWeakPtr())); |
+ return; |
+ } |
int new_connection_id = ++last_connection_id_; |
std::unique_ptr<RemoteSecurityKeyIpcServer> ipc_server( |
RemoteSecurityKeyIpcServer::Create( |
- new_connection_id, disconnect_timeout_, send_message_callback_, |
+ new_connection_id, peer_session_id, disconnect_timeout_, |
+ send_message_callback_, |
base::Bind(&GnubbyAuthHandlerWin::CloseSecurityKeyRequestIpcChannel, |
base::Unretained(this), new_connection_id))); |