Index: remoting/host/it2me/it2me_host.cc |
diff --git a/remoting/host/it2me/it2me_host.cc b/remoting/host/it2me/it2me_host.cc |
index 5856290ab3da28f7a884dd6cd5d1b98c99ca11bf..84bc3e3dba93d66770cf102bc15cfb3e9d967a75 100644 |
--- a/remoting/host/it2me/it2me_host.cc |
+++ b/remoting/host/it2me/it2me_host.cc |
@@ -174,11 +174,19 @@ void It2MeHost::FinishConnect() { |
} |
// Check the host domain policy. |
- if (!required_host_domain_.empty() && |
- !base::EndsWith(username_, std::string("@") + required_host_domain_, |
- base::CompareCase::INSENSITIVE_ASCII)) { |
- SetState(kInvalidDomainError, ""); |
- return; |
+ if (!required_host_domain_list_.empty()) { |
+ bool matched = false; |
+ for (const auto& domain : required_host_domain_list_) { |
+ if (base::EndsWith(username_, std::string("@") + domain, |
+ base::CompareCase::INSENSITIVE_ASCII)) { |
+ matched = true; |
+ break; |
+ } |
+ } |
+ if (!matched) { |
+ SetState(kInvalidDomainError, ""); |
+ return; |
+ } |
} |
// Generate a key pair for the Host to use. |
@@ -314,14 +322,23 @@ void It2MeHost::OnPolicyUpdate( |
&nat_policy)) { |
UpdateNatPolicy(nat_policy); |
} |
- std::string host_domain; |
- if (policies->GetString(policy::key::kRemoteAccessHostDomain, &host_domain)) { |
- UpdateHostDomainPolicy(host_domain); |
+ const base::ListValue* host_domain_list; |
+ if (policies->GetList(policy::key::kRemoteAccessHostDomainList, |
+ &host_domain_list)) { |
+ std::vector<std::string> host_domain_list_vector; |
+ for (const auto& value : *host_domain_list) { |
+ host_domain_list_vector.push_back(value.GetString()); |
+ } |
+ UpdateHostDomainListPolicy(std::move(host_domain_list_vector)); |
} |
- std::string client_domain; |
- if (policies->GetString(policy::key::kRemoteAccessHostClientDomain, |
- &client_domain)) { |
- UpdateClientDomainPolicy(client_domain); |
+ const base::ListValue* client_domain_list; |
+ if (policies->GetList(policy::key::kRemoteAccessHostClientDomainList, |
+ &client_domain_list)) { |
+ std::vector<std::string> client_domain_list_vector; |
+ for (const auto& value : *client_domain_list) { |
+ client_domain_list_vector.push_back(value.GetString()); |
+ } |
+ UpdateClientDomainListPolicy(std::move(client_domain_list_vector)); |
} |
policy_received_ = true; |
@@ -355,30 +372,34 @@ void It2MeHost::UpdateNatPolicy(bool nat_traversal_enabled) { |
nat_traversal_enabled_)); |
} |
-void It2MeHost::UpdateHostDomainPolicy(const std::string& host_domain) { |
+void It2MeHost::UpdateHostDomainListPolicy( |
+ std::vector<std::string> host_domain_list) { |
DCHECK(host_context_->network_task_runner()->BelongsToCurrentThread()); |
- VLOG(2) << "UpdateHostDomainPolicy: " << host_domain; |
+ VLOG(2) << "UpdateHostDomainListPolicy: " |
+ << base::JoinString(host_domain_list, ", "); |
// When setting a host domain policy, force disconnect any existing session. |
- if (!host_domain.empty() && IsRunning()) { |
+ if (!host_domain_list.empty() && IsRunning()) { |
DisconnectOnNetworkThread(); |
} |
- required_host_domain_ = host_domain; |
+ required_host_domain_list_ = std::move(host_domain_list); |
} |
-void It2MeHost::UpdateClientDomainPolicy(const std::string& client_domain) { |
+void It2MeHost::UpdateClientDomainListPolicy( |
+ std::vector<std::string> client_domain_list) { |
DCHECK(host_context_->network_task_runner()->BelongsToCurrentThread()); |
- VLOG(2) << "UpdateClientDomainPolicy: " << client_domain; |
+ VLOG(2) << "UpdateClientDomainPolicy: " |
+ << base::JoinString(client_domain_list, ", "); |
// When setting a client domain policy, disconnect any existing session. |
- if (!client_domain.empty() && IsRunning()) { |
+ if (!client_domain_list.empty() && IsRunning()) { |
DisconnectOnNetworkThread(); |
} |
- required_client_domain_ = client_domain; |
+ required_client_domain_list_ = std::move(client_domain_list); |
} |
void It2MeHost::SetState(It2MeHostState state, |
@@ -502,12 +523,18 @@ void It2MeHost::ValidateConnectionDetails( |
} |
// Check the client domain policy. |
- if (!required_client_domain_.empty()) { |
- if (!base::EndsWith(client_username, |
- std::string("@") + required_client_domain_, |
- base::CompareCase::INSENSITIVE_ASCII)) { |
+ if (!required_client_domain_list_.empty()) { |
+ bool matched = false; |
+ for (const auto& domain : required_client_domain_list_) { |
+ if (base::EndsWith(client_username, std::string("@") + domain, |
+ base::CompareCase::INSENSITIVE_ASCII)) { |
+ matched = true; |
+ break; |
+ } |
+ } |
+ if (!matched) { |
LOG(ERROR) << "Rejecting incoming connection from " << remote_jid |
- << ": Domain mismatch."; |
+ << ": Domain not allowed."; |
result_callback.Run(ValidationResult::ERROR_INVALID_ACCOUNT); |
DisconnectOnNetworkThread(); |
return; |