Index: remoting/host/it2me/it2me_host.cc |
diff --git a/remoting/host/it2me/it2me_host.cc b/remoting/host/it2me/it2me_host.cc |
index f95c7aae786491f42cc30b9c07d952675c41a0a5..e27a46f8da93956cb05888f715e63fecba779761 100644 |
--- a/remoting/host/it2me/it2me_host.cc |
+++ b/remoting/host/it2me/it2me_host.cc |
@@ -180,6 +180,21 @@ void It2MeHost::FinishConnect() { |
return; |
} |
+ if (!required_host_domain_list_.empty()) { |
+ bool matched = false; |
+ for (const std::string& domain : required_client_domain_list_) { |
+ if (base::EndsWith(username_, std::string("@") + domain, |
+ base::CompareCase::INSENSITIVE_ASCII)) { |
+ matched = true; |
+ break; |
+ } |
+ } |
+ if (!matched) { |
+ SetState(kInvalidDomainError, ""); |
+ return; |
+ } |
+ } |
+ |
// Generate a key pair for the Host to use. |
// TODO(wez): Move this to the worker thread. |
host_key_pair_ = RsaKeyPair::Generate(); |
@@ -321,11 +336,41 @@ void It2MeHost::OnPolicyUpdate( |
if (policies->GetString(policy::key::kRemoteAccessHostDomain, &host_domain)) { |
UpdateHostDomainPolicy(host_domain); |
} |
+ const base::ListValue* host_domain_list; |
+ if (policies->GetList(policy::key::kRemoteAccessHostDomainList, |
+ &host_domain_list)) { |
+ std::vector<std::string> host_domain_list_vector; |
+ for (const auto& value : *host_domain_list) { |
+ const base::StringValue* domain; |
+ if(!value->GetAsString(&domain)) { |
+ // Should be prevented by policy validation |
+ DCHECK(false); |
+ continue; |
+ } |
+ host_domain_list_vector.push_back(domain->GetString()); |
+ } |
+ UpdateHostDomainListPolicy(std::move(host_domain_list_vector)); |
+ } |
std::string client_domain; |
if (policies->GetString(policy::key::kRemoteAccessHostClientDomain, |
&client_domain)) { |
UpdateClientDomainPolicy(client_domain); |
} |
+ const base::ListValue* client_domain_list; |
+ if (policies->GetList(policy::key::kRemoteAccessHostClientDomainList, |
+ &client_domain_list)) { |
+ std::vector<std::string> client_domain_list_vector; |
+ for (const auto& value : *client_domain_list) { |
+ const base::StringValue* domain; |
+ if(!value->GetAsString(&domain)) { |
+ // Should be prevented by policy validation |
+ DCHECK(false); |
+ continue; |
+ } |
+ client_domain_list_vector.push_back(domain->GetString()); |
+ } |
+ UpdateClientDomainListPolicy(std::move(client_domain_list_vector)); |
+ } |
policy_received_ = true; |
@@ -371,6 +416,21 @@ void It2MeHost::UpdateHostDomainPolicy(const std::string& host_domain) { |
required_host_domain_ = host_domain; |
} |
+void It2MeHost::UpdateHostDomainListPolicy( |
+ std::vector<std::string> host_domain_list) { |
+ DCHECK(host_context_->network_task_runner()->BelongsToCurrentThread()); |
+ |
+ VLOG(2) << "UpdateHostDomainListPolicy: " |
+ << base::JoinString(host_domain_list, ", "); |
+ |
+ // When setting a host domain policy, force disconnect any existing session. |
+ if (!host_domain_list.empty() && IsRunning()) { |
+ DisconnectOnNetworkThread(); |
+ } |
+ |
+ required_host_domain_list_ = std::move(host_domain_list); |
+} |
+ |
void It2MeHost::UpdateClientDomainPolicy(const std::string& client_domain) { |
DCHECK(host_context_->network_task_runner()->BelongsToCurrentThread()); |
@@ -384,6 +444,21 @@ void It2MeHost::UpdateClientDomainPolicy(const std::string& client_domain) { |
required_client_domain_ = client_domain; |
} |
+void It2MeHost::UpdateClientDomainListPolicy( |
+ std::vector<std::string> client_domain_list) { |
+ DCHECK(host_context_->network_task_runner()->BelongsToCurrentThread()); |
+ |
+ VLOG(2) << "UpdateClientDomainListPolicy: " |
+ << base::JoinString(client_domain_list, ", "); |
+ |
+ // When setting a client domain policy, disconnect any existing session. |
+ if (!client_domain_list.empty() && IsRunning()) { |
+ DisconnectOnNetworkThread(); |
+ } |
+ |
+ required_client_domain_list_ = std::move(client_domain_list); |
+} |
+ |
void It2MeHost::SetState(It2MeHostState state, |
const std::string& error_message) { |
DCHECK(host_context_->network_task_runner()->BelongsToCurrentThread()); |
@@ -517,6 +592,24 @@ void It2MeHost::ValidateConnectionDetails( |
} |
} |
+ if (!required_client_domain_list_.empty()) { |
+ bool matched = false; |
+ for (const std::string& domain : required_client_domain_list_) { |
+ if (base::EndsWith(client_username, std::string("@") + domain, |
+ base::CompareCase::INSENSITIVE_ASCII)) { |
+ matched = true; |
+ break; |
+ } |
+ } |
+ if (!matched) { |
+ LOG(ERROR) << "Rejecting incoming connection from " << remote_jid |
+ << ": Domain not allowed."; |
+ result_callback.Run(ValidationResult::ERROR_INVALID_ACCOUNT); |
+ DisconnectOnNetworkThread(); |
+ return; |
+ } |
+ } |
+ |
HOST_LOG << "Client " << client_username << " connecting."; |
SetState(kConnecting, std::string()); |