Chromium Code Reviews| Index: components/certificate_transparency/log_dns_client.cc |
| diff --git a/components/certificate_transparency/log_dns_client.cc b/components/certificate_transparency/log_dns_client.cc |
| index ce7e8627a703783416b2fae22ac764c74e236400..76ff9f706bb3387692590cf30d5aa931fe7fb39d 100644 |
| --- a/components/certificate_transparency/log_dns_client.cc |
| +++ b/components/certificate_transparency/log_dns_client.cc |
| @@ -2,17 +2,17 @@ |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| #include "components/certificate_transparency/log_dns_client.h" |
| -#include <sstream> |
| - |
| #include "base/bind.h" |
| +#include "base/format_macros.h" |
|
Ryan Sleevi
2016/10/07 15:46:23
AFAICT, it should be sufficient to just include <s
Rob Percival
2016/10/11 18:02:56
Just tested this and I get a compilation error wit
|
| #include "base/location.h" |
| #include "base/logging.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/string_util.h" |
| +#include "base/strings/stringprintf.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "base/time/time.h" |
| #include "components/base32/base32.h" |
| #include "crypto/sha2.h" |
| #include "net/base/net_errors.h" |
| @@ -96,13 +96,15 @@ bool ParseAuditPath(const net::DnsResponse& response, |
| } |
| } // namespace |
| LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client, |
| - const net::NetLogWithSource& net_log) |
| + const net::NetLogWithSource& net_log, |
| + size_t max_concurrent_queries) |
| : dns_client_(std::move(dns_client)), |
| net_log_(net_log), |
| + max_concurrent_queries_(max_concurrent_queries), |
| weak_ptr_factory_(this) { |
| CHECK(dns_client_); |
| net::NetworkChangeNotifier::AddDNSObserver(this); |
| UpdateDnsConfig(); |
| } |
| @@ -126,10 +128,17 @@ void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log, |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, 0)); |
| return; |
| } |
| + if (HasMaxConcurrentQueriesInProgress()) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, |
| + base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, 0)); |
| + return; |
| + } |
| + |
| std::string encoded_leaf_hash = |
| base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING); |
| DCHECK_EQ(encoded_leaf_hash.size(), 52u); |
| net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
| @@ -138,18 +147,18 @@ void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log, |
| FROM_HERE, |
| base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, 0)); |
| return; |
| } |
| - std::ostringstream qname; |
| - qname << encoded_leaf_hash << ".hash." << domain_for_log << "."; |
| + std::string qname = base::StringPrintf( |
| + "%s.hash.%s.", encoded_leaf_hash.c_str(), domain_for_log.data()); |
| net::DnsTransactionFactory::CallbackType transaction_callback = base::Bind( |
| &LogDnsClient::QueryLeafIndexComplete, weak_ptr_factory_.GetWeakPtr()); |
| std::unique_ptr<net::DnsTransaction> dns_transaction = |
| - factory->CreateTransaction(qname.str(), net::dns_protocol::kTypeTXT, |
| + factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, |
| transaction_callback, net_log_); |
| dns_transaction->Start(); |
| leaf_index_queries_.push_back({std::move(dns_transaction), callback}); |
| } |
| @@ -160,11 +169,12 @@ void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log, |
| // 7-13 and 14-19) immediately. Currently, it sends only the first and then, |
| // based on the number of nodes received, sends the next query. The complexity |
| // of the code would increase though, as it would need to detect gaps in the |
| // audit proof caused by the server not responding with the anticipated number |
| // of nodes. Ownership of the proof would need to change, as it would be shared |
| -// between simultaneous DNS transactions. |
| +// between simultaneous DNS transactions. Throttling of queries would also need |
| +// to take into account this increase in parallelism. |
| void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, |
| uint64_t leaf_index, |
| uint64_t tree_size, |
| const AuditProofCallback& callback) { |
| if (domain_for_log.empty() || leaf_index >= tree_size) { |
| @@ -172,10 +182,17 @@ void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, |
| FROM_HERE, |
| base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, nullptr)); |
| return; |
| } |
| + if (HasMaxConcurrentQueriesInProgress()) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, |
| + base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, nullptr)); |
| + return; |
| + } |
| + |
| std::unique_ptr<net::ct::MerkleAuditProof> proof( |
| new net::ct::MerkleAuditProof); |
| proof->leaf_index = leaf_index; |
| // TODO(robpercival): Once a "tree_size" field is added to MerkleAuditProof, |
| // pass |tree_size| to QueryAuditProofNodes using that. |
| @@ -243,21 +260,21 @@ void LogDnsClient::QueryAuditProofNodes( |
| FROM_HERE, |
| base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, nullptr)); |
| return; |
| } |
| - std::ostringstream qname; |
| - qname << node_index << "." << proof->leaf_index << "." << tree_size |
| - << ".tree." << domain_for_log << "."; |
| + std::string qname = base::StringPrintf( |
| + "%" PRIu64 ".%" PRIu64 ".%" PRIu64 ".tree.%s.", node_index, |
| + proof->leaf_index, tree_size, domain_for_log.data()); |
| net::DnsTransactionFactory::CallbackType transaction_callback = |
| base::Bind(&LogDnsClient::QueryAuditProofNodesComplete, |
| weak_ptr_factory_.GetWeakPtr(), base::Passed(std::move(proof)), |
| domain_for_log, tree_size); |
| std::unique_ptr<net::DnsTransaction> dns_transaction = |
| - factory->CreateTransaction(qname.str(), net::dns_protocol::kTypeTXT, |
| + factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, |
| transaction_callback, net_log_); |
| dns_transaction->Start(); |
| audit_proof_queries_.push_back({std::move(dns_transaction), callback}); |
| } |
| @@ -318,10 +335,18 @@ void LogDnsClient::QueryAuditProofNodesComplete( |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, |
| base::Bind(query.callback, net::OK, base::Passed(std::move(proof)))); |
| } |
| +bool LogDnsClient::HasMaxConcurrentQueriesInProgress() const { |
| + const size_t queries_in_progress = |
| + leaf_index_queries_.size() + audit_proof_queries_.size(); |
| + |
| + return max_concurrent_queries_ != 0 && |
| + queries_in_progress >= max_concurrent_queries_; |
| +} |
| + |
| void LogDnsClient::UpdateDnsConfig() { |
| net::DnsConfig config; |
| net::NetworkChangeNotifier::GetDnsConfig(&config); |
| if (config.IsValid()) |
| dns_client_->SetConfig(config); |