Chromium Code Reviews| Index: components/certificate_transparency/log_dns_client.cc |
| diff --git a/components/certificate_transparency/log_dns_client.cc b/components/certificate_transparency/log_dns_client.cc |
| index 5f15ebbd71a5c2ff9f658c1814c5aaa209df82fe..712f184713a9754448dcfd8161df4a5ff524f2a4 100644 |
| --- a/components/certificate_transparency/log_dns_client.cc |
| +++ b/components/certificate_transparency/log_dns_client.cc |
| @@ -79,10 +79,196 @@ bool ParseAuditPath(const net::DnsResponse& response, |
| return true; |
| } |
| } // namespace |
| +// Encapsulates the state machine required to get an audit proof from a Merkle |
| +// leaf hash. This requires a DNS request to obtain the leaf index, then a |
| +// series of DNS requests to get the nodes of the proof. |
| +class LogDnsClient::AuditProofQuery { |
|
Ryan Sleevi
2016/10/03 23:51:07
See comments on the other CL about splitting defin
Rob Percival
2016/10/04 16:01:57
Done.
|
| + public: |
| + using CompletionCallback = |
| + base::Callback<void(int net_error, AuditProofQuery* query)>; |
| + |
| + // The LogDnsClient is guaranteed to outlive the AuditProofQuery, so it's safe |
| + // to leave ownership of |dns_client| with LogDnsClient. |
| + AuditProofQuery(net::DnsClient* dns_client, |
| + const std::string& domain_for_log, |
| + uint64_t tree_size, |
| + const net::NetLogWithSource& net_log) |
| + : domain_for_log_(domain_for_log), |
| + tree_size_(tree_size), |
| + dns_client_(dns_client), |
| + net_log_(net_log), |
| + weak_ptr_factory_(this) { |
| + DCHECK(dns_client_); |
| + DCHECK(!domain_for_log_.empty()); |
| + } |
| + |
| + // Start the query. |
| + void Start(base::StringPiece leaf_hash, CompletionCallback callback) { |
| + current_dns_transaction_.reset(); |
| + proof_.reset(new net::ct::MerkleAuditProof); |
|
Ryan Sleevi
2016/10/03 23:51:08
base::MakeUnique<> ?
Rob Percival
2016/10/04 16:01:57
Done.
|
| + callback_ = callback; |
| + QueryLeafIndex(leaf_hash); |
| + } |
| + |
| + std::unique_ptr<net::ct::MerkleAuditProof> TakeProof() { |
| + return std::move(proof_); |
| + } |
| + |
| + private: |
| + void QueryLeafIndex(base::StringPiece leaf_hash) { |
| + std::string encoded_leaf_hash = base32::Base32Encode( |
| + leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING); |
| + DCHECK_EQ(encoded_leaf_hash.size(), 52u); |
| + |
| + std::string qname = base::StringPrintf( |
| + "%s.hash.%s.", encoded_leaf_hash.c_str(), domain_for_log_.data()); |
| + |
| + net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
| + if (factory == nullptr) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, |
| + base::Bind(callback_, net::Error::ERR_NAME_RESOLUTION_FAILED, |
| + base::Unretained(this))); |
| + return; |
| + } |
| + |
| + net::DnsTransactionFactory::CallbackType transaction_callback = |
| + base::Bind(&LogDnsClient::AuditProofQuery::QueryLeafIndexComplete, |
| + weak_ptr_factory_.GetWeakPtr()); |
| + |
| + current_dns_transaction_ = factory->CreateTransaction( |
| + qname, net::dns_protocol::kTypeTXT, transaction_callback, net_log_); |
| + |
| + current_dns_transaction_->Start(); |
| + } |
| + |
| + void QueryLeafIndexComplete(net::DnsTransaction* transaction, |
| + int net_error, |
| + const net::DnsResponse* response) { |
| + // If we've received no response but no net::error either (shouldn't |
| + // happen), |
| + // report the response as invalid. |
| + if (response == nullptr && net_error == net::OK) { |
| + net_error = net::ERR_INVALID_RESPONSE; |
| + } |
| + |
| + if (net_error != net::OK) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, base::Bind(callback_, net_error, base::Unretained(this))); |
| + return; |
| + } |
| + |
| + if (!ParseLeafIndex(*response, &proof_->leaf_index)) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, base::Bind(callback_, net::ERR_DNS_MALFORMED_RESPONSE, |
| + base::Unretained(this))); |
| + return; |
| + } |
| + |
| + // Reject leaf index if it is out-of-range. |
| + // This indicates either: |
| + // a) the wrong tree_size was provided. |
| + // b) the wrong leaf hash was provided. |
| + // c) there is a bug server-side. |
| + // The first two are more likely, so return ERR_INVALID_ARGUMENT. |
| + if (proof_->leaf_index >= tree_size_) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, base::Bind(callback_, net::ERR_INVALID_ARGUMENT, |
| + base::Unretained(this))); |
| + return; |
| + } |
| + |
| + // QueryAuditProof for the first batch of audit proof_ nodes (i.e. starting |
| + // from 0). |
| + QueryAuditProofNodes(0 /* start node index */); |
| + } |
| + |
| + // Queries a CT log to retrieve part of an audit |proof|. The |node_index| |
| + // indicates which node of the audit proof/ should be requested. The CT log |
| + // may return up to 7 nodes, starting from |node_index| (this is the maximum |
| + // that will fit in a DNS UDP packet). The nodes will be appended to |
| + // |proof->nodes|. |
| + void QueryAuditProofNodes(uint64_t node_index) { |
| + DCHECK_LT(proof_->leaf_index, tree_size_); |
| + DCHECK_LT(node_index, net::ct::CalculateAuditPathLength(proof_->leaf_index, |
| + tree_size_)); |
| + |
| + std::string qname = base::StringPrintf( |
| + "%" PRIu64 ".%" PRIu64 ".%" PRIu64 ".tree.%s.", node_index, |
| + proof_->leaf_index, tree_size_, domain_for_log_.data()); |
| + |
| + net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
| + if (factory == nullptr) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, |
| + base::Bind(callback_, net::Error::ERR_NAME_RESOLUTION_FAILED, |
| + base::Unretained(this))); |
| + return; |
| + } |
| + |
| + net::DnsTransactionFactory::CallbackType transaction_callback = |
| + base::Bind(&LogDnsClient::AuditProofQuery::QueryAuditProofNodesComplete, |
| + weak_ptr_factory_.GetWeakPtr()); |
| + |
| + current_dns_transaction_ = factory->CreateTransaction( |
| + qname, net::dns_protocol::kTypeTXT, transaction_callback, net_log_); |
| + current_dns_transaction_->Start(); |
| + } |
| + |
| + void QueryAuditProofNodesComplete(net::DnsTransaction* transaction, |
| + int net_error, |
| + const net::DnsResponse* response) { |
| + // If we receive no response but no net::error either (shouldn't happen), |
| + // report the response as invalid. |
| + if (response == nullptr && net_error == net::OK) { |
| + net_error = net::ERR_INVALID_RESPONSE; |
| + } |
| + |
| + if (net_error != net::OK) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, base::Bind(callback_, net_error, base::Unretained(this))); |
| + return; |
| + } |
| + |
| + const uint64_t audit_path_length = |
| + net::ct::CalculateAuditPathLength(proof_->leaf_index, tree_size_); |
| + // The calculated |audit_path_length| can't ever be greater than 64, so |
| + // deriving the amount of memory to reserve from the untrusted |leaf_index| |
| + // is safe. |
| + proof_->nodes.reserve(audit_path_length); |
| + |
| + if (!ParseAuditPath(*response, proof_.get())) { |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, base::Bind(callback_, net::ERR_DNS_MALFORMED_RESPONSE, |
| + base::Unretained(this))); |
| + return; |
| + } |
| + |
| + const uint64_t audit_path_nodes_received = proof_->nodes.size(); |
| + if (audit_path_nodes_received < audit_path_length) { |
| + QueryAuditProofNodes(audit_path_nodes_received); |
| + return; |
| + } |
| + |
| + base::ThreadTaskRunnerHandle::Get()->PostTask( |
| + FROM_HERE, base::Bind(callback_, net::OK, base::Unretained(this))); |
| + } |
| + |
| + std::string domain_for_log_; |
| + // TODO(robpercival): Remove |tree_size| once |proof_| has a tree_size member. |
| + uint64_t tree_size_; |
| + std::unique_ptr<net::ct::MerkleAuditProof> proof_; |
| + CompletionCallback callback_; |
| + net::DnsClient* dns_client_; |
| + std::unique_ptr<net::DnsTransaction> current_dns_transaction_; |
| + net::NetLogWithSource net_log_; |
| + base::WeakPtrFactory<AuditProofQuery> weak_ptr_factory_; |
| +}; |
| + |
| LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client, |
| const net::NetLogWithSource& net_log, |
| size_t max_concurrent_queries) |
| : dns_client_(std::move(dns_client)), |
| net_log_(net_log), |
| @@ -103,67 +289,25 @@ void LogDnsClient::OnDNSChanged() { |
| void LogDnsClient::OnInitialDNSConfigRead() { |
| UpdateDnsConfig(); |
| } |
| -void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log, |
| - base::StringPiece leaf_hash, |
| - const LeafIndexCallback& callback) { |
| - if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, 0)); |
| - return; |
| - } |
| - |
| - if (HasMaxConcurrentQueriesInProgress()) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, |
| - base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, 0)); |
| - return; |
| - } |
| - |
| - std::string encoded_leaf_hash = |
| - base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING); |
| - DCHECK_EQ(encoded_leaf_hash.size(), 52u); |
| - |
| - net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
| - if (factory == nullptr) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, |
| - base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, 0)); |
| - return; |
| - } |
| - |
| - std::string qname = base::StringPrintf( |
| - "%s.hash.%s.", encoded_leaf_hash.c_str(), domain_for_log.data()); |
| - |
| - net::DnsTransactionFactory::CallbackType transaction_callback = base::Bind( |
| - &LogDnsClient::QueryLeafIndexComplete, weak_ptr_factory_.GetWeakPtr()); |
| - |
| - std::unique_ptr<net::DnsTransaction> dns_transaction = |
| - factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, |
| - transaction_callback, net_log_); |
| - |
| - dns_transaction->Start(); |
| - leaf_index_queries_.push_back({std::move(dns_transaction), callback}); |
| -} |
| - |
| // The performance of this could be improved by sending all of the expected |
| // queries up front. Each response can contain a maximum of 7 audit path nodes, |
| // so for an audit proof of size 20, it could send 3 queries (for nodes 0-6, |
| // 7-13 and 14-19) immediately. Currently, it sends only the first and then, |
| // based on the number of nodes received, sends the next query. The complexity |
| // of the code would increase though, as it would need to detect gaps in the |
| // audit proof caused by the server not responding with the anticipated number |
| // of nodes. Ownership of the proof would need to change, as it would be shared |
| // between simultaneous DNS transactions. Throttling of queries would also need |
| // to take into account this increase in parallelism. |
| -void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, |
| - uint64_t leaf_index, |
| +void LogDnsClient::QueryAuditProof(const std::string& domain_for_log, |
| + base::StringPiece leaf_hash, |
| uint64_t tree_size, |
| const AuditProofCallback& callback) { |
| - if (domain_for_log.empty() || leaf_index >= tree_size) { |
| + if (domain_for_log.empty() || leaf_hash.size() != crypto::kSHA256Length) { |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, |
| base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, nullptr)); |
| return; |
| } |
| @@ -173,162 +317,48 @@ void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, |
| FROM_HERE, |
| base::Bind(callback, net::Error::ERR_TEMPORARILY_THROTTLED, nullptr)); |
| return; |
| } |
| - std::unique_ptr<net::ct::MerkleAuditProof> proof( |
| - new net::ct::MerkleAuditProof); |
| - proof->leaf_index = leaf_index; |
| - // TODO(robpercival): Once a "tree_size" field is added to MerkleAuditProof, |
| - // pass |tree_size| to QueryAuditProofNodes using that. |
| + audit_proof_queries_.emplace_back(new AuditProofQuery( |
| + dns_client_.get(), domain_for_log, tree_size, net_log_)); |
| - // Query for the first batch of audit proof nodes (i.e. starting from 0). |
| - QueryAuditProofNodes(std::move(proof), domain_for_log, tree_size, 0, |
| - callback); |
| -} |
| - |
| -void LogDnsClient::QueryLeafIndexComplete(net::DnsTransaction* transaction, |
| - int net_error, |
| - const net::DnsResponse* response) { |
| - auto query_iterator = |
| - std::find_if(leaf_index_queries_.begin(), leaf_index_queries_.end(), |
| - [transaction](const Query<LeafIndexCallback>& query) { |
| - return query.transaction.get() == transaction; |
| - }); |
| - if (query_iterator == leaf_index_queries_.end()) { |
| - NOTREACHED(); |
| - return; |
| - } |
| - const Query<LeafIndexCallback> query = std::move(*query_iterator); |
| - leaf_index_queries_.erase(query_iterator); |
| - |
| - // If we've received no response but no net::error either (shouldn't happen), |
| - // report the response as invalid. |
| - if (response == nullptr && net_error == net::OK) { |
| - net_error = net::ERR_INVALID_RESPONSE; |
| - } |
| - |
| - if (net_error != net::OK) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, base::Bind(query.callback, net_error, 0)); |
| - return; |
| - } |
| + AuditProofQuery::CompletionCallback internal_callback = |
| + base::Bind(&LogDnsClient::QueryAuditProofComplete, |
| + weak_ptr_factory_.GetWeakPtr(), callback); |
| - uint64_t leaf_index; |
| - if (!ParseLeafIndex(*response, &leaf_index)) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, |
| - base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, 0)); |
| - return; |
| - } |
| - |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, base::Bind(query.callback, net::OK, leaf_index)); |
| + audit_proof_queries_.back()->Start(leaf_hash, internal_callback); |
| } |
| -void LogDnsClient::QueryAuditProofNodes( |
| - std::unique_ptr<net::ct::MerkleAuditProof> proof, |
| - base::StringPiece domain_for_log, |
| - uint64_t tree_size, |
| - uint64_t node_index, |
| - const AuditProofCallback& callback) { |
| - // Preconditions that should be guaranteed internally by this class. |
| - DCHECK(proof); |
| - DCHECK(!domain_for_log.empty()); |
| - DCHECK_LT(proof->leaf_index, tree_size); |
| - DCHECK_LT(node_index, |
| - net::ct::CalculateAuditPathLength(proof->leaf_index, tree_size)); |
| +void LogDnsClient::QueryAuditProofComplete(const AuditProofCallback& callback, |
| + int result, |
| + AuditProofQuery* query) { |
| + DCHECK(query); |
| - net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); |
| - if (factory == nullptr) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, |
| - base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, nullptr)); |
| - return; |
| + std::unique_ptr<net::ct::MerkleAuditProof> proof; |
| + if (result == net::OK) { |
| + proof = query->TakeProof(); |
| } |
| - std::string qname = base::StringPrintf( |
| - "%" PRIu64 ".%" PRIu64 ".%" PRIu64 ".tree.%s.", node_index, |
| - proof->leaf_index, tree_size, domain_for_log.data()); |
| - |
| - net::DnsTransactionFactory::CallbackType transaction_callback = |
| - base::Bind(&LogDnsClient::QueryAuditProofNodesComplete, |
| - weak_ptr_factory_.GetWeakPtr(), base::Passed(std::move(proof)), |
| - domain_for_log, tree_size); |
| - |
| - std::unique_ptr<net::DnsTransaction> dns_transaction = |
| - factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, |
| - transaction_callback, net_log_); |
| - dns_transaction->Start(); |
| - audit_proof_queries_.push_back({std::move(dns_transaction), callback}); |
| -} |
| - |
| -void LogDnsClient::QueryAuditProofNodesComplete( |
| - std::unique_ptr<net::ct::MerkleAuditProof> proof, |
| - base::StringPiece domain_for_log, |
| - uint64_t tree_size, |
| - net::DnsTransaction* transaction, |
| - int net_error, |
| - const net::DnsResponse* response) { |
| - // Preconditions that should be guaranteed internally by this class. |
| - DCHECK(proof); |
| - DCHECK(!domain_for_log.empty()); |
| - |
| + // Finished with the query - destroy it. Do this before invoking |callback| |
| + // in case it wants to perform another query and |audit_proof_queries_| is |
| + // already at its limit (as specified by |max_concurrent_queries_|. |
|
Ryan Sleevi
2016/10/03 23:51:08
Because you're asynchronously invoking this callba
Rob Percival
2016/10/04 16:01:57
Done.
|
| auto query_iterator = |
| std::find_if(audit_proof_queries_.begin(), audit_proof_queries_.end(), |
| - [transaction](const Query<AuditProofCallback>& query) { |
| - return query.transaction.get() == transaction; |
| + [query](const std::unique_ptr<AuditProofQuery>& p) { |
| + return p.get() == query; |
| }); |
| - |
| - if (query_iterator == audit_proof_queries_.end()) { |
| - NOTREACHED(); |
| - return; |
| - } |
| - const Query<AuditProofCallback> query = std::move(*query_iterator); |
| + DCHECK(query_iterator != audit_proof_queries_.end()); |
| audit_proof_queries_.erase(query_iterator); |
| - // If we've received no response but no net::error either (shouldn't happen), |
| - // report the response as invalid. |
| - if (response == nullptr && net_error == net::OK) { |
| - net_error = net::ERR_INVALID_RESPONSE; |
| - } |
| - |
| - if (net_error != net::OK) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, base::Bind(query.callback, net_error, nullptr)); |
| - return; |
| - } |
| - |
| - const uint64_t audit_path_length = |
| - net::ct::CalculateAuditPathLength(proof->leaf_index, tree_size); |
| - proof->nodes.reserve(audit_path_length); |
| - |
| - if (!ParseAuditPath(*response, proof.get())) { |
| - base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, |
| - base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, nullptr)); |
| - return; |
| - } |
| - |
| - const uint64_t audit_path_nodes_received = proof->nodes.size(); |
| - if (audit_path_nodes_received < audit_path_length) { |
| - QueryAuditProofNodes(std::move(proof), domain_for_log, tree_size, |
| - audit_path_nodes_received, query.callback); |
| - return; |
| - } |
| - |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| - FROM_HERE, |
| - base::Bind(query.callback, net::OK, base::Passed(std::move(proof)))); |
| + FROM_HERE, base::Bind(callback, result, base::Passed(&proof))); |
| } |
| bool LogDnsClient::HasMaxConcurrentQueriesInProgress() const { |
| - const size_t queries_in_progress = |
| - leaf_index_queries_.size() + audit_proof_queries_.size(); |
| - |
| return max_concurrent_queries_ != 0 && |
| - queries_in_progress >= max_concurrent_queries_; |
| + audit_proof_queries_.size() >= max_concurrent_queries_; |
| } |
| void LogDnsClient::UpdateDnsConfig() { |
| net::DnsConfig config; |
| net::NetworkChangeNotifier::GetDnsConfig(&config); |