Index: components/certificate_transparency/log_dns_client.cc |
diff --git a/components/certificate_transparency/log_dns_client.cc b/components/certificate_transparency/log_dns_client.cc |
index 76ff9f706bb3387692590cf30d5aa931fe7fb39d..60e82304fbe1e98e80c42862318cd6333cf21077 100644 |
--- a/components/certificate_transparency/log_dns_client.cc |
+++ b/components/certificate_transparency/log_dns_client.cc |
@@ -6,10 +6,11 @@ |
#include "base/bind.h" |
#include "base/format_macros.h" |
#include "base/location.h" |
#include "base/logging.h" |
+#include "base/memory/ptr_util.h" |
#include "base/strings/string_number_conversions.h" |
#include "base/strings/string_util.h" |
#include "base/strings/stringprintf.h" |
#include "base/threading/thread_task_runner_handle.h" |
#include "base/time/time.h" |
@@ -95,10 +96,230 @@ bool ParseAuditPath(const net::DnsResponse& response, |
return true; |
} |
} // namespace |
+// Encapsulates the state machine required to get an audit proof from a Merkle |
+// leaf hash. This requires a DNS request to obtain the leaf index, then a |
+// series of DNS requests to get the nodes of the proof. |
+class LogDnsClient::AuditProofQuery { |
+ public: |
+ using CompletionCallback = |
+ base::Callback<void(int net_error, AuditProofQuery* query)>; |
+ |
+ // The LogDnsClient is guaranteed to outlive the AuditProofQuery, so it's safe |
+ // to leave ownership of |dns_client| with LogDnsClient. |
+ AuditProofQuery(net::DnsClient* dns_client, |
+ const std::string& domain_for_log, |
+ uint64_t tree_size, |
+ const net::NetLogWithSource& net_log); |
+ |
+ // Begins the process of getting an audit proof for the CT log entry with a |
+ // leaf hash of |leaf_hash|. The |callback| will be invoked when finished. |
+ void Start(base::StringPiece leaf_hash, CompletionCallback callback); |
+ |
+ // Transfers the audit proof to the caller. |
+ // Only call this once the query has completed, otherwise the proof will be |
+ // incomplete. |
+ std::unique_ptr<net::ct::MerkleAuditProof> TakeProof(); |
+ |
+ private: |
+ // Requests the leaf index of the CT log entry with |leaf_hash|. |
+ void QueryLeafIndex(base::StringPiece leaf_hash); |
+ |
+ // Processes the response to a leaf index request. |
+ // The received leaf index will be added to the proof. |
+ void QueryLeafIndexComplete(net::DnsTransaction* transaction, |
+ int net_error, |
+ const net::DnsResponse* response); |
+ |
+ // Queries a CT log to retrieve part of an audit proof. The |node_index| |
+ // indicates which node of the audit proof/ should be requested. The CT log |
+ // may return up to 7 nodes, starting from |node_index| (this is the maximum |
+ // that will fit in a DNS UDP packet). The nodes will be appended to the |
+ // proof. |
+ void QueryAuditProofNodes(uint64_t node_index); |
+ |
+ // Processes the response to an audit proof request. |
+ // This will contain some, but not necessarily all, of the audit proof nodes. |
+ void QueryAuditProofNodesComplete(net::DnsTransaction* transaction, |
+ int net_error, |
+ const net::DnsResponse* response); |
+ |
+ std::string domain_for_log_; |
+ // TODO(robpercival): Remove |tree_size| once |proof_| has a tree_size member. |
+ uint64_t tree_size_; |
+ std::unique_ptr<net::ct::MerkleAuditProof> proof_; |
+ CompletionCallback callback_; |
+ net::DnsClient* dns_client_; |
+ std::unique_ptr<net::DnsTransaction> current_dns_transaction_; |
+ net::NetLogWithSource net_log_; |
+ base::WeakPtrFactory<AuditProofQuery> weak_ptr_factory_; |
+}; |
+ |
+LogDnsClient::AuditProofQuery::AuditProofQuery( |
+ net::DnsClient* dns_client, |
+ const std::string& domain_for_log, |
+ uint64_t tree_size, |
+ const net::NetLogWithSource& net_log) |
+ : domain_for_log_(domain_for_log), |
+ tree_size_(tree_size), |
+ dns_client_(dns_client), |
+ net_log_(net_log), |
+ weak_ptr_factory_(this) { |
+ DCHECK(dns_client_); |
+ DCHECK(!domain_for_log_.empty()); |
+} |
+ |
+void LogDnsClient::AuditProofQuery::Start(base::StringPiece leaf_hash, |
+ CompletionCallback callback) { |
+ current_dns_transaction_.reset(); |
+ proof_ = base::MakeUnique<net::ct::MerkleAuditProof>(); |
+ callback_ = callback; |
+ QueryLeafIndex(leaf_hash); |
+} |
+ |
+std::unique_ptr<net::ct::MerkleAuditProof> |
+LogDnsClient::AuditProofQuery::TakeProof() { |
+ return std::move(proof_); |
+} |
+ |
+void LogDnsClient::AuditProofQuery::QueryLeafIndex( |
+ base::StringPiece leaf_hash) { |
+ std::string encoded_leaf_hash = |
+ base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING); |
+ DCHECK_EQ(encoded_leaf_hash.size(), 52u); |
+ |
+ std::string qname = base::StringPrintf( |
+ "%s.hash.%s.", encoded_leaf_hash.c_str(), domain_for_log_.data()); |
+ |
+ net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
+ if (factory == nullptr) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net::Error::ERR_NAME_RESOLUTION_FAILED, |
+ base::Unretained(this))); |
+ return; |
+ } |
+ |
+ net::DnsTransactionFactory::CallbackType transaction_callback = |
+ base::Bind(&LogDnsClient::AuditProofQuery::QueryLeafIndexComplete, |
+ weak_ptr_factory_.GetWeakPtr()); |
+ |
+ current_dns_transaction_ = factory->CreateTransaction( |
+ qname, net::dns_protocol::kTypeTXT, transaction_callback, net_log_); |
+ |
+ current_dns_transaction_->Start(); |
+} |
+ |
+void LogDnsClient::AuditProofQuery::QueryLeafIndexComplete( |
+ net::DnsTransaction* transaction, |
+ int net_error, |
+ const net::DnsResponse* response) { |
+ // If we've received no response but no net::error either (shouldn't |
+ // happen), |
+ // report the response as invalid. |
+ if (response == nullptr && net_error == net::OK) { |
+ net_error = net::ERR_INVALID_RESPONSE; |
+ } |
+ |
+ if (net_error != net::OK) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net_error, base::Unretained(this))); |
+ return; |
+ } |
+ |
+ if (!ParseLeafIndex(*response, &proof_->leaf_index)) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net::ERR_DNS_MALFORMED_RESPONSE, |
+ base::Unretained(this))); |
+ return; |
+ } |
+ |
+ // Reject leaf index if it is out-of-range. |
+ // This indicates either: |
+ // a) the wrong tree_size was provided. |
+ // b) the wrong leaf hash was provided. |
+ // c) there is a bug server-side. |
+ // The first two are more likely, so return ERR_INVALID_ARGUMENT. |
+ if (proof_->leaf_index >= tree_size_) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net::ERR_INVALID_ARGUMENT, |
+ base::Unretained(this))); |
+ return; |
+ } |
+ |
+ // QueryAuditProof for the first batch of audit proof_ nodes (i.e. starting |
+ // from 0). |
+ QueryAuditProofNodes(0 /* start node index */); |
+} |
+ |
+void LogDnsClient::AuditProofQuery::QueryAuditProofNodes(uint64_t node_index) { |
+ DCHECK_LT(proof_->leaf_index, tree_size_); |
+ DCHECK_LT(node_index, |
+ net::ct::CalculateAuditPathLength(proof_->leaf_index, tree_size_)); |
+ |
+ std::string qname = base::StringPrintf( |
+ "%" PRIu64 ".%" PRIu64 ".%" PRIu64 ".tree.%s.", node_index, |
+ proof_->leaf_index, tree_size_, domain_for_log_.data()); |
+ |
+ net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
+ if (factory == nullptr) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net::Error::ERR_NAME_RESOLUTION_FAILED, |
+ base::Unretained(this))); |
+ return; |
+ } |
+ |
+ net::DnsTransactionFactory::CallbackType transaction_callback = |
+ base::Bind(&LogDnsClient::AuditProofQuery::QueryAuditProofNodesComplete, |
+ weak_ptr_factory_.GetWeakPtr()); |
+ |
+ current_dns_transaction_ = factory->CreateTransaction( |
+ qname, net::dns_protocol::kTypeTXT, transaction_callback, net_log_); |
+ current_dns_transaction_->Start(); |
+} |
+ |
+void LogDnsClient::AuditProofQuery::QueryAuditProofNodesComplete( |
+ net::DnsTransaction* transaction, |
+ int net_error, |
+ const net::DnsResponse* response) { |
+ // If we receive no response but no net::error either (shouldn't happen), |
+ // report the response as invalid. |
+ if (response == nullptr && net_error == net::OK) { |
+ net_error = net::ERR_INVALID_RESPONSE; |
+ } |
+ |
+ if (net_error != net::OK) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net_error, base::Unretained(this))); |
+ return; |
+ } |
+ |
+ const uint64_t audit_path_length = |
+ net::ct::CalculateAuditPathLength(proof_->leaf_index, tree_size_); |
+ // The calculated |audit_path_length| can't ever be greater than 64, so |
+ // deriving the amount of memory to reserve from the untrusted |leaf_index| |
+ // is safe. |
+ proof_->nodes.reserve(audit_path_length); |
+ |
+ if (!ParseAuditPath(*response, proof_.get())) { |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net::ERR_DNS_MALFORMED_RESPONSE, |
+ base::Unretained(this))); |
+ return; |
+ } |
+ |
+ const uint64_t audit_path_nodes_received = proof_->nodes.size(); |
+ if (audit_path_nodes_received < audit_path_length) { |
+ QueryAuditProofNodes(audit_path_nodes_received); |
+ return; |
+ } |
+ |
+ base::ThreadTaskRunnerHandle::Get()->PostTask( |
+ FROM_HERE, base::Bind(callback_, net::OK, base::Unretained(this))); |
+} |
+ |
LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client, |
const net::NetLogWithSource& net_log, |
size_t max_concurrent_queries) |
: dns_client_(std::move(dns_client)), |
net_log_(net_log), |
@@ -119,67 +340,25 @@ void LogDnsClient::OnDNSChanged() { |
void LogDnsClient::OnInitialDNSConfigRead() { |
UpdateDnsConfig(); |
} |
-void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log, |
- base::StringPiece leaf_hash, |
- const LeafIndexCallback& callback) { |
- if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, 0)); |
- return; |
- } |
- |
- if (HasMaxConcurrentQueriesInProgress()) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, |
- base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, 0)); |
- return; |
- } |
- |
- std::string encoded_leaf_hash = |
- base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING); |
- DCHECK_EQ(encoded_leaf_hash.size(), 52u); |
- |
- net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
- if (factory == nullptr) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, |
- base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, 0)); |
- return; |
- } |
- |
- std::string qname = base::StringPrintf( |
- "%s.hash.%s.", encoded_leaf_hash.c_str(), domain_for_log.data()); |
- |
- net::DnsTransactionFactory::CallbackType transaction_callback = base::Bind( |
- &LogDnsClient::QueryLeafIndexComplete, weak_ptr_factory_.GetWeakPtr()); |
- |
- std::unique_ptr<net::DnsTransaction> dns_transaction = |
- factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, |
- transaction_callback, net_log_); |
- |
- dns_transaction->Start(); |
- leaf_index_queries_.push_back({std::move(dns_transaction), callback}); |
-} |
- |
// The performance of this could be improved by sending all of the expected |
// queries up front. Each response can contain a maximum of 7 audit path nodes, |
// so for an audit proof of size 20, it could send 3 queries (for nodes 0-6, |
// 7-13 and 14-19) immediately. Currently, it sends only the first and then, |
// based on the number of nodes received, sends the next query. The complexity |
// of the code would increase though, as it would need to detect gaps in the |
// audit proof caused by the server not responding with the anticipated number |
// of nodes. Ownership of the proof would need to change, as it would be shared |
// between simultaneous DNS transactions. Throttling of queries would also need |
// to take into account this increase in parallelism. |
-void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, |
- uint64_t leaf_index, |
+void LogDnsClient::QueryAuditProof(const std::string& domain_for_log, |
+ base::StringPiece leaf_hash, |
uint64_t tree_size, |
const AuditProofCallback& callback) { |
- if (domain_for_log.empty() || leaf_index >= tree_size) { |
+ if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) { |
base::ThreadTaskRunnerHandle::Get()->PostTask( |
FROM_HERE, |
base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, nullptr)); |
return; |
} |
@@ -189,162 +368,46 @@ void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, |
FROM_HERE, |
base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, nullptr)); |
return; |
} |
- std::unique_ptr<net::ct::MerkleAuditProof> proof( |
- new net::ct::MerkleAuditProof); |
- proof->leaf_index = leaf_index; |
- // TODO(robpercival): Once a "tree_size" field is added to MerkleAuditProof, |
- // pass |tree_size| to QueryAuditProofNodes using that. |
+ audit_proof_queries_.emplace_back(new AuditProofQuery( |
+ dns_client_.get(), domain_for_log, tree_size, net_log_)); |
- // Query for the first batch of audit proof nodes (i.e. starting from 0). |
- QueryAuditProofNodes(std::move(proof), domain_for_log, tree_size, 0, |
- callback); |
-} |
- |
-void LogDnsClient::QueryLeafIndexComplete(net::DnsTransaction* transaction, |
- int net_error, |
- const net::DnsResponse* response) { |
- auto query_iterator = |
- std::find_if(leaf_index_queries_.begin(), leaf_index_queries_.end(), |
- [transaction](const Query<LeafIndexCallback>& query) { |
- return query.transaction.get() == transaction; |
- }); |
- if (query_iterator == leaf_index_queries_.end()) { |
- NOTREACHED(); |
- return; |
- } |
- const Query<LeafIndexCallback> query = std::move(*query_iterator); |
- leaf_index_queries_.erase(query_iterator); |
- |
- // If we've received no response but no net::error either (shouldn't happen), |
- // report the response as invalid. |
- if (response == nullptr && net_error == net::OK) { |
- net_error = net::ERR_INVALID_RESPONSE; |
- } |
- |
- if (net_error != net::OK) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, base::Bind(query.callback, net_error, 0)); |
- return; |
- } |
+ AuditProofQuery::CompletionCallback internal_callback = |
+ base::Bind(&LogDnsClient::QueryAuditProofComplete, |
+ weak_ptr_factory_.GetWeakPtr(), callback); |
- uint64_t leaf_index; |
- if (!ParseLeafIndex(*response, &leaf_index)) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, |
- base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, 0)); |
- return; |
- } |
- |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, base::Bind(query.callback, net::OK, leaf_index)); |
+ audit_proof_queries_.back()->Start(leaf_hash, internal_callback); |
} |
-void LogDnsClient::QueryAuditProofNodes( |
- std::unique_ptr<net::ct::MerkleAuditProof> proof, |
- base::StringPiece domain_for_log, |
- uint64_t tree_size, |
- uint64_t node_index, |
- const AuditProofCallback& callback) { |
- // Preconditions that should be guaranteed internally by this class. |
- DCHECK(proof); |
- DCHECK(!domain_for_log.empty()); |
- DCHECK_LT(proof->leaf_index, tree_size); |
- DCHECK_LT(node_index, |
- net::ct::CalculateAuditPathLength(proof->leaf_index, tree_size)); |
+void LogDnsClient::QueryAuditProofComplete(const AuditProofCallback& callback, |
+ int result, |
+ AuditProofQuery* query) { |
+ DCHECK(query); |
- net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
- if (factory == nullptr) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, |
- base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, nullptr)); |
- return; |
+ std::unique_ptr<net::ct::MerkleAuditProof> proof; |
+ if (result == net::OK) { |
+ proof = query->TakeProof(); |
} |
- std::string qname = base::StringPrintf( |
- "%" PRIu64 ".%" PRIu64 ".%" PRIu64 ".tree.%s.", node_index, |
- proof->leaf_index, tree_size, domain_for_log.data()); |
- |
- net::DnsTransactionFactory::CallbackType transaction_callback = |
- base::Bind(&LogDnsClient::QueryAuditProofNodesComplete, |
- weak_ptr_factory_.GetWeakPtr(), base::Passed(std::move(proof)), |
- domain_for_log, tree_size); |
- |
- std::unique_ptr<net::DnsTransaction> dns_transaction = |
- factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, |
- transaction_callback, net_log_); |
- dns_transaction->Start(); |
- audit_proof_queries_.push_back({std::move(dns_transaction), callback}); |
-} |
- |
-void LogDnsClient::QueryAuditProofNodesComplete( |
- std::unique_ptr<net::ct::MerkleAuditProof> proof, |
- base::StringPiece domain_for_log, |
- uint64_t tree_size, |
- net::DnsTransaction* transaction, |
- int net_error, |
- const net::DnsResponse* response) { |
- // Preconditions that should be guaranteed internally by this class. |
- DCHECK(proof); |
- DCHECK(!domain_for_log.empty()); |
- |
+ // Finished with the query - destroy it. |
auto query_iterator = |
std::find_if(audit_proof_queries_.begin(), audit_proof_queries_.end(), |
- [transaction](const Query<AuditProofCallback>& query) { |
- return query.transaction.get() == transaction; |
+ [query](const std::unique_ptr<AuditProofQuery>& p) { |
+ return p.get() == query; |
}); |
- |
- if (query_iterator == audit_proof_queries_.end()) { |
- NOTREACHED(); |
- return; |
- } |
- const Query<AuditProofCallback> query = std::move(*query_iterator); |
+ DCHECK(query_iterator != audit_proof_queries_.end()); |
audit_proof_queries_.erase(query_iterator); |
- // If we've received no response but no net::error either (shouldn't happen), |
- // report the response as invalid. |
- if (response == nullptr && net_error == net::OK) { |
- net_error = net::ERR_INVALID_RESPONSE; |
- } |
- |
- if (net_error != net::OK) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, base::Bind(query.callback, net_error, nullptr)); |
- return; |
- } |
- |
- const uint64_t audit_path_length = |
- net::ct::CalculateAuditPathLength(proof->leaf_index, tree_size); |
- proof->nodes.reserve(audit_path_length); |
- |
- if (!ParseAuditPath(*response, proof.get())) { |
- base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, |
- base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, nullptr)); |
- return; |
- } |
- |
- const uint64_t audit_path_nodes_received = proof->nodes.size(); |
- if (audit_path_nodes_received < audit_path_length) { |
- QueryAuditProofNodes(std::move(proof), domain_for_log, tree_size, |
- audit_path_nodes_received, query.callback); |
- return; |
- } |
- |
base::ThreadTaskRunnerHandle::Get()->PostTask( |
- FROM_HERE, |
- base::Bind(query.callback, net::OK, base::Passed(std::move(proof)))); |
+ FROM_HERE, base::Bind(callback, result, base::Passed(&proof))); |
} |
bool LogDnsClient::HasMaxConcurrentQueriesInProgress() const { |
- const size_t queries_in_progress = |
- leaf_index_queries_.size() + audit_proof_queries_.size(); |
- |
return max_concurrent_queries_ != 0 && |
- queries_in_progress >= max_concurrent_queries_; |
+ audit_proof_queries_.size() >= max_concurrent_queries_; |
} |
void LogDnsClient::UpdateDnsConfig() { |
net::DnsConfig config; |
net::NetworkChangeNotifier::GetDnsConfig(&config); |