Index: remoting/host/security_key/security_key_auth_handler_linux.cc |
diff --git a/remoting/host/security_key/security_key_auth_handler_linux.cc b/remoting/host/security_key/security_key_auth_handler_linux.cc |
index 346a2febb396c3ed14d25a80041e02610ca61be1..b938d0f5522bf3ff301d56578f13bdac5d971f25 100644 |
--- a/remoting/host/security_key/security_key_auth_handler_linux.cc |
+++ b/remoting/host/security_key/security_key_auth_handler_linux.cc |
@@ -4,18 +4,29 @@ |
#include "remoting/host/security_key/security_key_auth_handler.h" |
-#include <stdint.h> |
#include <unistd.h> |
+#include <cstdint> |
#include <memory> |
#include "base/bind.h" |
+#include "base/callback.h" |
#include "base/files/file_util.h" |
#include "base/lazy_instance.h" |
+#include "base/location.h" |
#include "base/logging.h" |
-#include "base/stl_util.h" |
+#include "base/memory/ptr_util.h" |
+#include "base/memory/ref_counted.h" |
+#include "base/memory/weak_ptr.h" |
+#include "base/single_thread_task_runner.h" |
+#include "base/synchronization/lock.h" |
+// TODO: DEBUG DEBUG DEBUG -> REMOVE LOGGING |
+#include "base/threading/platform_thread.h" |
+#include "base/threading/thread.h" |
+// END DEBUG DEBUG DEBUG |
#include "base/threading/thread_checker.h" |
#include "base/threading/thread_restrictions.h" |
+#include "base/threading/thread_task_runner_handle.h" |
#include "base/values.h" |
#include "net/base/completion_callback.h" |
#include "net/base/net_errors.h" |
@@ -53,11 +64,13 @@ namespace remoting { |
class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler { |
public: |
- SecurityKeyAuthHandlerLinux(); |
+ explicit SecurityKeyAuthHandlerLinux( |
+ scoped_refptr<base::SingleThreadTaskRunner> file_task_runner); |
~SecurityKeyAuthHandlerLinux() override; |
private: |
- typedef std::map<int, SecurityKeySocket*> ActiveSockets; |
+ // The actual implementation resides in Core class. |
+ class Core; |
// SecurityKeyAuthHandler interface. |
void CreateSecurityKeyConnection() override; |
@@ -69,6 +82,54 @@ class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler { |
size_t GetActiveConnectionCountForTest() const override; |
void SetRequestTimeoutForTest(base::TimeDelta timeout) override; |
+ // Ensures SecurityKeyAuthHandlerLinux methods are called on the same thread. |
+ base::ThreadChecker thread_checker_; |
+ |
+ // Used to pass security key extension messages to the client. |
+ SendMessageCallback send_message_callback_; |
+ |
+ // Timeout used for a request. |
+ base::TimeDelta request_timeout_; |
+ |
+ // Used by |core_| to perform File IO via Unix Domain Sockets. |
+ scoped_refptr<base::SingleThreadTaskRunner> file_task_runner_; |
+ |
+ // Responsible for handling security key requests on |file_task_runner_|. |
+ std::unique_ptr<Core> core_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(SecurityKeyAuthHandlerLinux); |
+}; |
+ |
+class SecurityKeyAuthHandlerLinux::Core { |
+ public: |
+ Core(const SendMessageCallback& send_message_callback, |
+ base::TimeDelta timeout); |
+ ~Core(); |
+ |
+ // Creates the accept socket, binds to it, and begins accepting listeners. |
+ // Must be called from the file IO task runner. |
+ void StartListening(); |
+ |
+ // Returns whether |security_key_connection_id| corresponds to an active |
+ // connection. Can be called from any thread. |
+ bool IsValidConnectionId(int security_key_connection_id) const; |
+ |
+ // Passes |response| to the socket associated with |
+ // |security_key_connection_id|. Must be called from the file IO task runner. |
+ void SendClientResponse(int security_key_connection_id, |
+ const std::string& response); |
+ |
+ // Sends an SSH error code to the socket assocated with |
+ // |security_key_connection_id| and closes it. Must be called from the file |
+ // IO task runner. |
+ void SendErrorAndCloseConnection(int security_key_connection_id); |
+ |
+ // Returns the number of active connections. Can be called from any thread. |
+ size_t GetActiveConnectionCountForTest() const; |
+ |
+ private: |
+ typedef std::map<int, std::unique_ptr<SecurityKeySocket>> ActiveSockets; |
+ |
// Starts listening for connection. |
void DoAccept(); |
@@ -101,22 +162,32 @@ class SecurityKeyAuthHandlerLinux : public SecurityKeyAuthHandler { |
SendMessageCallback send_message_callback_; |
// The last assigned security key connection id. |
- int last_connection_id_; |
+ int last_connection_id_ = 0; |
// Sockets by connection id used to process gnubbyd requests. |
ActiveSockets active_sockets_; |
+ // Protects access to |active_sockets_|. Marked as mutable so this lock can |
+ // be acquired inside methods marked 'const'. |
+ mutable base::Lock active_sockets_lock_; |
+ |
// Timeout used for a request. |
base::TimeDelta request_timeout_; |
- DISALLOW_COPY_AND_ASSIGN(SecurityKeyAuthHandlerLinux); |
+ // Used to run |send_message_callback_| on the main thread. |
+ scoped_refptr<base::SingleThreadTaskRunner> main_task_runner_; |
+ |
+ base::WeakPtrFactory<Core> weak_factory_; |
+ |
+ DISALLOW_COPY_AND_ASSIGN(Core); |
}; |
std::unique_ptr<SecurityKeyAuthHandler> SecurityKeyAuthHandler::Create( |
ClientSessionDetails* client_session_details, |
- const SendMessageCallback& send_message_callback) { |
+ const SendMessageCallback& send_message_callback, |
+ scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) { |
std::unique_ptr<SecurityKeyAuthHandler> auth_handler( |
- new SecurityKeyAuthHandlerLinux()); |
+ new SecurityKeyAuthHandlerLinux(file_task_runner)); |
auth_handler->SetSendMessageCallback(send_message_callback); |
return auth_handler; |
} |
@@ -126,54 +197,160 @@ void SecurityKeyAuthHandler::SetSecurityKeySocketName( |
g_security_key_socket_name.Get() = security_key_socket_name; |
} |
-SecurityKeyAuthHandlerLinux::SecurityKeyAuthHandlerLinux() |
- : last_connection_id_(0), |
- request_timeout_( |
- base::TimeDelta::FromSeconds(kDefaultRequestTimeoutSeconds)) {} |
+SecurityKeyAuthHandlerLinux::SecurityKeyAuthHandlerLinux( |
+ scoped_refptr<base::SingleThreadTaskRunner> file_task_runner) |
+ : request_timeout_( |
+ base::TimeDelta::FromSeconds(kDefaultRequestTimeoutSeconds)), |
+ file_task_runner_(file_task_runner) { |
+ LOG(INFO) << "** SecurityKeyAuthHandlerLinux() called on: " |
+ << base::PlatformThread::GetName(); |
+} |
SecurityKeyAuthHandlerLinux::~SecurityKeyAuthHandlerLinux() { |
- STLDeleteValues(&active_sockets_); |
+ if (core_) { |
+ file_task_runner_->DeleteSoon(FROM_HERE, core_.release()); |
+ } |
} |
void SecurityKeyAuthHandlerLinux::CreateSecurityKeyConnection() { |
DCHECK(thread_checker_.CalledOnValidThread()); |
DCHECK(!g_security_key_socket_name.Get().empty()); |
+ DCHECK(!core_); |
+ LOG(INFO) << "** CreateSecurityKeyConnection() called on: " |
+ << base::PlatformThread::GetName(); |
+ |
+ core_.reset(new SecurityKeyAuthHandlerLinux::Core(send_message_callback_, |
+ request_timeout_)); |
+ |
+ // base::Unretained is safe to use as |core_| is released in the D'tor on |
+ // this thread and destroyed on |file_task_runner_|. |
+ file_task_runner_->PostTask( |
+ FROM_HERE, base::Bind(&SecurityKeyAuthHandlerLinux::Core::StartListening, |
+ base::Unretained(core_.get()))); |
+} |
- { |
- // DeleteFile() is a blocking operation, but so is creation of the unix |
- // socket below. Consider moving this class to a different thread if this |
- // causes any problems. See crbug.com/509807. |
- // TODO(joedow): Since this code now runs as a host extension, we should |
- // perform our IO on a separate thread: crbug.com/591739 |
- base::ThreadRestrictions::ScopedAllowIO allow_io; |
- |
- // If the file already exists, a socket in use error is returned. |
- base::DeleteFile(g_security_key_socket_name.Get(), false); |
+bool SecurityKeyAuthHandlerLinux::IsValidConnectionId( |
+ int security_key_connection_id) const { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ return core_ && core_->IsValidConnectionId(security_key_connection_id); |
+} |
+ |
+void SecurityKeyAuthHandlerLinux::SendClientResponse( |
+ int security_key_connection_id, |
+ const std::string& response) { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ DCHECK(core_); |
+ |
+ if (!IsValidConnectionId(security_key_connection_id)) { |
+ LOG(WARNING) << "Unknown gnubby-auth data connection: '" |
+ << security_key_connection_id << "'"; |
+ return; |
} |
- HOST_LOG << "Listening for security key requests on " |
- << g_security_key_socket_name.Get().value(); |
+ // base::Unretained is safe to use as |core_| is released in the D'tor on |
+ // this thread and destroyed on |file_task_runner_|. |
+ file_task_runner_->PostTask( |
+ FROM_HERE, |
+ base::Bind(&SecurityKeyAuthHandlerLinux::Core::SendClientResponse, |
+ base::Unretained(core_.get()), security_key_connection_id, |
+ base::ConstRef(response))); |
+} |
+ |
+void SecurityKeyAuthHandlerLinux::SendErrorAndCloseConnection( |
+ int security_key_connection_id) { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ DCHECK(core_); |
+ |
+ if (!IsValidConnectionId(security_key_connection_id)) { |
+ LOG(WARNING) << "Unknown gnubby-auth data connection: '" |
+ << security_key_connection_id << "'"; |
+ return; |
+ } |
+ |
+ // base::Unretained is safe to use as |core_| is released in the D'tor on |
+ // this thread and destroyed on |file_task_runner_|. |
+ file_task_runner_->PostTask( |
+ FROM_HERE, |
+ base::Bind( |
+ &SecurityKeyAuthHandlerLinux::Core::SendErrorAndCloseConnection, |
+ base::Unretained(core_.get()), security_key_connection_id)); |
+} |
+ |
+void SecurityKeyAuthHandlerLinux::SetSendMessageCallback( |
+ const SendMessageCallback& callback) { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ DCHECK(!core_) << "Core object already created, callback already set."; |
+ send_message_callback_ = callback; |
+} |
+ |
+size_t SecurityKeyAuthHandlerLinux::GetActiveConnectionCountForTest() const { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ return core_ ? core_->GetActiveConnectionCountForTest() : 0; |
+} |
+ |
+void SecurityKeyAuthHandlerLinux::SetRequestTimeoutForTest( |
+ base::TimeDelta timeout) { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ DCHECK(!core_) << "Core object already created, timeout already set."; |
+ request_timeout_ = timeout; |
+} |
+ |
+SecurityKeyAuthHandlerLinux::Core::Core( |
+ const SendMessageCallback& send_message_callback, |
+ base::TimeDelta timeout) |
+ : send_message_callback_(send_message_callback), |
+ request_timeout_(timeout), |
+ weak_factory_(this) { |
+ LOG(INFO) << "** Core() called on: " << base::PlatformThread::GetName(); |
+ thread_checker_.DetachFromThread(); |
+ main_task_runner_ = base::ThreadTaskRunnerHandle::Get(); |
+} |
+ |
+SecurityKeyAuthHandlerLinux::Core::~Core() { |
+ LOG(INFO) << "** ~Core() called on: " << base::PlatformThread::GetName(); |
+} |
+ |
+void SecurityKeyAuthHandlerLinux::Core::StartListening() { |
+ LOG(INFO) << "** Core::StartListening() called on: " |
+ << base::PlatformThread::GetName(); |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ DCHECK(!auth_socket_); |
+ |
+ // If the file already exists, a socket in use error is returned. |
+ base::DeleteFile(g_security_key_socket_name.Get(), false); |
auth_socket_.reset( |
new net::UnixDomainServerSocket(base::Bind(MatchUid), false)); |
+ |
int rv = auth_socket_->BindAndListen(g_security_key_socket_name.Get().value(), |
/*backlog=*/1); |
if (rv != net::OK) { |
LOG(ERROR) << "Failed to open socket for auth requests: '" << rv << "'"; |
return; |
} |
+ |
+ HOST_LOG << "Listening for security key requests on " |
+ << g_security_key_socket_name.Get().value(); |
+ |
DoAccept(); |
} |
-bool SecurityKeyAuthHandlerLinux::IsValidConnectionId( |
+bool SecurityKeyAuthHandlerLinux::Core::IsValidConnectionId( |
int security_key_connection_id) const { |
+ LOG(INFO) << "** Core::IsValidConnectionId() called on: " |
+ << base::PlatformThread::GetName(); |
+ base::AutoLock auto_lock(active_sockets_lock_); |
return GetSocketForConnectionId(security_key_connection_id) != |
active_sockets_.end(); |
} |
-void SecurityKeyAuthHandlerLinux::SendClientResponse( |
+void SecurityKeyAuthHandlerLinux::Core::SendClientResponse( |
int security_key_connection_id, |
const std::string& response) { |
+ LOG(INFO) << "** Core::SendClientResponse() called on: " |
+ << base::PlatformThread::GetName(); |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ |
ActiveSockets::const_iterator iter = |
GetSocketForConnectionId(security_key_connection_id); |
if (iter != active_sockets_.end()) { |
@@ -184,8 +361,12 @@ void SecurityKeyAuthHandlerLinux::SendClientResponse( |
} |
} |
-void SecurityKeyAuthHandlerLinux::SendErrorAndCloseConnection( |
+void SecurityKeyAuthHandlerLinux::Core::SendErrorAndCloseConnection( |
int security_key_connection_id) { |
+ LOG(INFO) << "** Core::SendErrorAndCloseConnection() called on: " |
+ << base::PlatformThread::GetName(); |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ |
ActiveSockets::const_iterator iter = |
GetSocketForConnectionId(security_key_connection_id); |
if (iter != active_sockets_.end()) { |
@@ -197,29 +378,29 @@ void SecurityKeyAuthHandlerLinux::SendErrorAndCloseConnection( |
} |
} |
-void SecurityKeyAuthHandlerLinux::SetSendMessageCallback( |
- const SendMessageCallback& callback) { |
- send_message_callback_ = callback; |
-} |
- |
-size_t SecurityKeyAuthHandlerLinux::GetActiveConnectionCountForTest() const { |
+size_t SecurityKeyAuthHandlerLinux::Core::GetActiveConnectionCountForTest() |
+ const { |
+ base::AutoLock auto_lock(active_sockets_lock_); |
return active_sockets_.size(); |
-} |
+}; |
-void SecurityKeyAuthHandlerLinux::SetRequestTimeoutForTest( |
- base::TimeDelta timeout) { |
- request_timeout_ = timeout; |
-} |
+void SecurityKeyAuthHandlerLinux::Core::DoAccept() { |
+ LOG(INFO) << "** Core::DoAccept() called on: " |
+ << base::PlatformThread::GetName(); |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
-void SecurityKeyAuthHandlerLinux::DoAccept() { |
int result = auth_socket_->Accept( |
- &accept_socket_, base::Bind(&SecurityKeyAuthHandlerLinux::OnAccepted, |
- base::Unretained(this))); |
- if (result != net::ERR_IO_PENDING) |
+ &accept_socket_, |
+ base::Bind(&SecurityKeyAuthHandlerLinux::Core::OnAccepted, |
+ weak_factory_.GetWeakPtr())); |
+ if (result != net::ERR_IO_PENDING) { |
OnAccepted(result); |
+ } |
} |
-void SecurityKeyAuthHandlerLinux::OnAccepted(int result) { |
+void SecurityKeyAuthHandlerLinux::Core::OnAccepted(int result) { |
+ LOG(INFO) << "** Core::OnAccepted() called on: " |
+ << base::PlatformThread::GetName(); |
DCHECK(thread_checker_.CalledOnValidThread()); |
DCHECK_NE(net::ERR_IO_PENDING, result); |
@@ -231,19 +412,24 @@ void SecurityKeyAuthHandlerLinux::OnAccepted(int result) { |
int security_key_connection_id = ++last_connection_id_; |
SecurityKeySocket* socket = new SecurityKeySocket( |
std::move(accept_socket_), request_timeout_, |
- base::Bind(&SecurityKeyAuthHandlerLinux::RequestTimedOut, |
- base::Unretained(this), security_key_connection_id)); |
- active_sockets_[security_key_connection_id] = socket; |
+ base::Bind(&SecurityKeyAuthHandlerLinux::Core::RequestTimedOut, |
+ weak_factory_.GetWeakPtr(), security_key_connection_id)); |
+ { |
+ base::AutoLock auto_lock(active_sockets_lock_); |
+ active_sockets_[security_key_connection_id] = base::WrapUnique(socket); |
+ } |
socket->StartReadingRequest( |
- base::Bind(&SecurityKeyAuthHandlerLinux::OnReadComplete, |
- base::Unretained(this), security_key_connection_id)); |
+ base::Bind(&SecurityKeyAuthHandlerLinux::Core::OnReadComplete, |
+ weak_factory_.GetWeakPtr(), security_key_connection_id)); |
// Continue accepting new connections. |
DoAccept(); |
} |
-void SecurityKeyAuthHandlerLinux::OnReadComplete( |
+void SecurityKeyAuthHandlerLinux::Core::OnReadComplete( |
int security_key_connection_id) { |
+ LOG(INFO) << "** Core::OnReadComplete() called on: " |
+ << base::PlatformThread::GetName(); |
DCHECK(thread_checker_.CalledOnValidThread()); |
ActiveSockets::const_iterator iter = |
@@ -256,33 +442,43 @@ void SecurityKeyAuthHandlerLinux::OnReadComplete( |
} |
HOST_LOG << "Received security key request: " << GetCommandCode(request_data); |
- send_message_callback_.Run(security_key_connection_id, request_data); |
+ main_task_runner_->PostTask( |
+ FROM_HERE, base::Bind(send_message_callback_, security_key_connection_id, |
+ request_data)); |
iter->second->StartReadingRequest( |
- base::Bind(&SecurityKeyAuthHandlerLinux::OnReadComplete, |
- base::Unretained(this), security_key_connection_id)); |
+ base::Bind(&SecurityKeyAuthHandlerLinux::Core::OnReadComplete, |
+ weak_factory_.GetWeakPtr(), security_key_connection_id)); |
} |
-SecurityKeyAuthHandlerLinux::ActiveSockets::const_iterator |
-SecurityKeyAuthHandlerLinux::GetSocketForConnectionId( |
+SecurityKeyAuthHandlerLinux::Core::ActiveSockets::const_iterator |
+SecurityKeyAuthHandlerLinux::Core::GetSocketForConnectionId( |
int security_key_connection_id) const { |
return active_sockets_.find(security_key_connection_id); |
} |
-void SecurityKeyAuthHandlerLinux::SendErrorAndCloseActiveSocket( |
+void SecurityKeyAuthHandlerLinux::Core::SendErrorAndCloseActiveSocket( |
const ActiveSockets::const_iterator& iter) { |
+ LOG(INFO) << "** Core::SendErrorAndCloseActiveSocket() called on: " |
+ << base::PlatformThread::GetName(); |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
+ |
iter->second->SendSshError(); |
- delete iter->second; |
- active_sockets_.erase(iter); |
+ { |
+ base::AutoLock auto_lock(active_sockets_lock_); |
+ active_sockets_.erase(iter); |
+ } |
} |
-void SecurityKeyAuthHandlerLinux::RequestTimedOut( |
+void SecurityKeyAuthHandlerLinux::Core::RequestTimedOut( |
int security_key_connection_id) { |
+ DCHECK(thread_checker_.CalledOnValidThread()); |
HOST_LOG << "SecurityKey request timed out"; |
ActiveSockets::const_iterator iter = |
- active_sockets_.find(security_key_connection_id); |
- if (iter != active_sockets_.end()) |
+ GetSocketForConnectionId(security_key_connection_id); |
+ if (iter != active_sockets_.end()) { |
SendErrorAndCloseActiveSocket(iter); |
+ } |
} |
} // namespace remoting |