Index: remoting/host/it2me/it2me_native_messaging_host_unittest.cc |
diff --git a/remoting/host/it2me/it2me_native_messaging_host_unittest.cc b/remoting/host/it2me/it2me_native_messaging_host_unittest.cc |
index 830e2b38f2ad800f5984d5bbcf25f1448eda9f0f..5c2c4b83e70d77a011f05d198c01ceafc7aebb29 100644 |
--- a/remoting/host/it2me/it2me_native_messaging_host_unittest.cc |
+++ b/remoting/host/it2me/it2me_native_messaging_host_unittest.cc |
@@ -24,6 +24,7 @@ |
#include "base/values.h" |
#include "components/policy/core/common/fake_async_policy_loader.h" |
#include "components/policy/core/common/mock_policy_service.h" |
+#include "components/policy/policy_constants.h" |
#include "net/base/file_stream.h" |
#include "remoting/base/auto_thread_task_runner.h" |
#include "remoting/host/chromoting_host_context.h" |
@@ -195,6 +196,7 @@ class It2MeNativeMessagingHostTest : public testing::Test { |
void TearDown() override; |
protected: |
+ void SetPolicies(const base::DictionaryValue& dict); |
std::unique_ptr<base::DictionaryValue> ReadMessageFromOutputPipe(); |
void WriteMessageToInputPipe(const base::Value& message); |
@@ -202,6 +204,7 @@ class It2MeNativeMessagingHostTest : public testing::Test { |
void VerifyErrorResponse(); |
void VerifyConnectResponses(int request_id); |
void VerifyDisconnectResponses(int request_id); |
+ void VerifyPolicyErrorResponse(); |
// The Host process should shut down when it receives a malformed request. |
// This is tested by sending a known-good request, followed by |message|, |
@@ -210,9 +213,13 @@ class It2MeNativeMessagingHostTest : public testing::Test { |
void TestBadRequest(const base::Value& message, bool expect_error_response); |
void TestConnect(); |
+ void SendConnectMessage(int id); |
+ void SendDisconnectMessage(int id); |
+ |
private: |
void StartHost(); |
void ExitTest(); |
+ void ExitPolicyRunLoop(); |
// Each test creates two unidirectional pipes: "input" and "output". |
// It2MeNativeMessagingHost reads from input_read_file and writes to |
@@ -230,6 +237,12 @@ class It2MeNativeMessagingHostTest : public testing::Test { |
std::unique_ptr<base::Thread> host_thread_; |
std::unique_ptr<base::RunLoop> host_run_loop_; |
+ std::unique_ptr<base::RunLoop> policy_run_loop_; |
+ |
+ // Retain a raw pointer to |policy_loader_| in order to control the policy |
+ // contents. |
+ policy::FakeAsyncPolicyLoader* policy_loader_ = nullptr; |
+ |
// Task runner of the host thread. |
scoped_refptr<AutoThreadTaskRunner> host_task_runner_; |
std::unique_ptr<remoting::NativeMessagingPipe> pipe_; |
@@ -281,6 +294,26 @@ void It2MeNativeMessagingHostTest::TearDown() { |
output_read_file_.Close(); |
} |
+void It2MeNativeMessagingHostTest::SetPolicies( |
+ const base::DictionaryValue& dict) { |
+ DCHECK(test_message_loop_->task_runner()->RunsTasksOnCurrentThread()); |
+ // Copy |dict| into |policy_bundle|. |
+ policy::PolicyNamespace policy_namespace = |
+ policy::PolicyNamespace(policy::POLICY_DOMAIN_CHROME, std::string()); |
+ policy::PolicyBundle policy_bundle; |
+ policy::PolicyMap& policy_map = policy_bundle.Get(policy_namespace); |
+ policy_map.LoadFrom(&dict, policy::POLICY_LEVEL_MANDATORY, |
+ policy::POLICY_SCOPE_MACHINE, |
+ policy::POLICY_SOURCE_CLOUD); |
+ |
+ // Simulate a policy update and wait for it to complete. |
+ policy_run_loop_.reset(new base::RunLoop); |
+ policy_loader_->SetPolicies(policy_bundle); |
+ policy_loader_->PostReloadOnBackgroundThread(true /* force reload asap */); |
+ policy_run_loop_->Run(); |
+ policy_run_loop_.reset(nullptr); |
+} |
+ |
std::unique_ptr<base::DictionaryValue> |
It2MeNativeMessagingHostTest::ReadMessageFromOutputPipe() { |
while (true) { |
@@ -441,6 +474,14 @@ void It2MeNativeMessagingHostTest::VerifyDisconnectResponses(int request_id) { |
} |
} |
+void It2MeNativeMessagingHostTest::VerifyPolicyErrorResponse() { |
+ std::unique_ptr<base::DictionaryValue> response = ReadMessageFromOutputPipe(); |
+ ASSERT_TRUE(response); |
+ std::string type; |
+ ASSERT_TRUE(response->GetString("type", &type)); |
+ ASSERT_EQ("policyError", type); |
+} |
+ |
void It2MeNativeMessagingHostTest::TestBadRequest(const base::Value& message, |
bool expect_error_response) { |
base::DictionaryValue good_message; |
@@ -476,17 +517,24 @@ void It2MeNativeMessagingHostTest::StartHost() { |
new PipeMessagingChannel(std::move(input_read_file), |
std::move(output_write_file))); |
- // Creating a native messaging host with a mock It2MeHostFactory. |
+ // Creating a native messaging host with a mock It2MeHostFactory and policy |
+ // loader. |
std::unique_ptr<ChromotingHostContext> context = |
ChromotingHostContext::Create(host_task_runner_); |
+ auto policy_loader = |
+ base::MakeUnique<policy::FakeAsyncPolicyLoader>(host_task_runner_); |
+ policy_loader_ = policy_loader.get(); |
std::unique_ptr<PolicyWatcher> policy_watcher = |
- PolicyWatcher::CreateFromPolicyLoaderForTesting( |
- base::MakeUnique<policy::FakeAsyncPolicyLoader>( |
- base::ThreadTaskRunnerHandle::Get())); |
- std::unique_ptr<extensions::NativeMessageHost> it2me_host( |
+ PolicyWatcher::CreateFromPolicyLoaderForTesting(std::move(policy_loader)); |
+ std::unique_ptr<It2MeNativeMessagingHost> it2me_host( |
new It2MeNativeMessagingHost( |
/*needs_elevation=*/false, std::move(policy_watcher), |
std::move(context), base::WrapUnique(new MockIt2MeHostFactory()))); |
+ it2me_host->SetPolicyErrorClosureForTesting( |
+ base::Bind(base::IgnoreResult(&base::TaskRunner::PostTask), |
+ test_message_loop_->task_runner(), FROM_HERE, |
+ base::Bind(&It2MeNativeMessagingHostTest::ExitPolicyRunLoop, |
+ base::Unretained(this)))); |
it2me_host->Start(pipe_.get()); |
pipe_->Start(std::move(it2me_host), std::move(channel)); |
@@ -507,12 +555,16 @@ void It2MeNativeMessagingHostTest::ExitTest() { |
test_run_loop_->Quit(); |
} |
-void It2MeNativeMessagingHostTest::TestConnect() { |
- base::DictionaryValue connect_message; |
- int next_id = 0; |
+void It2MeNativeMessagingHostTest::ExitPolicyRunLoop() { |
+ DCHECK(test_message_loop_->task_runner()->RunsTasksOnCurrentThread()); |
+ if (policy_run_loop_) { |
+ policy_run_loop_->Quit(); |
+ } |
+} |
- // Send the "connect" request. |
- connect_message.SetInteger("id", ++next_id); |
+void It2MeNativeMessagingHostTest::SendConnectMessage(int id) { |
+ base::DictionaryValue connect_message; |
+ connect_message.SetInteger("id", id); |
connect_message.SetString("type", "connect"); |
connect_message.SetString("xmppServerAddress", "talk.google.com:5222"); |
connect_message.SetBoolean("xmppServerUseTls", true); |
@@ -520,14 +572,21 @@ void It2MeNativeMessagingHostTest::TestConnect() { |
connect_message.SetString("userName", "chromo.pyauto@gmail.com"); |
connect_message.SetString("authServiceWithToken", "oauth2:sometoken"); |
WriteMessageToInputPipe(connect_message); |
+} |
- VerifyConnectResponses(next_id); |
- |
+void It2MeNativeMessagingHostTest::SendDisconnectMessage(int id) { |
base::DictionaryValue disconnect_message; |
- disconnect_message.SetInteger("id", ++next_id); |
+ disconnect_message.SetInteger("id", id); |
disconnect_message.SetString("type", "disconnect"); |
WriteMessageToInputPipe(disconnect_message); |
+} |
+void It2MeNativeMessagingHostTest::TestConnect() { |
+ int next_id = 1; |
+ SendConnectMessage(next_id); |
+ VerifyConnectResponses(next_id); |
+ ++next_id; |
+ SendDisconnectMessage(next_id); |
VerifyDisconnectResponses(next_id); |
} |
@@ -590,4 +649,23 @@ TEST_F(It2MeNativeMessagingHostTest, InvalidType) { |
TestBadRequest(message, true); |
} |
+// Verify rejection if type is unrecognized. |
+TEST_F(It2MeNativeMessagingHostTest, BadPoliciesBeforeConnect) { |
+ base::DictionaryValue bad_policy; |
+ bad_policy.SetInteger(policy::key::kRemoteAccessHostFirewallTraversal, 1); |
+ SetPolicies(bad_policy); |
+ SendConnectMessage(1); |
+ VerifyPolicyErrorResponse(); |
+} |
+ |
+// Verify rejection if type is unrecognized. |
+TEST_F(It2MeNativeMessagingHostTest, BadPoliciesAfterConnect) { |
+ base::DictionaryValue bad_policy; |
+ bad_policy.SetInteger(policy::key::kRemoteAccessHostFirewallTraversal, 1); |
+ SendConnectMessage(1); |
+ VerifyConnectResponses(1); |
+ SetPolicies(bad_policy); |
+ VerifyPolicyErrorResponse(); |
+} |
+ |
} // namespace remoting |