OLD | NEW |
---|---|
(Empty) | |
1 // Copyright 2016 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "components/certificate_transparency/log_dns_client.h" | |
6 | |
7 #include "base/bind.h" | |
8 #include "base/location.h" | |
9 #include "base/logging.h" | |
10 #include "base/strings/string_number_conversions.h" | |
11 #include "base/strings/string_util.h" | |
12 #include "base/threading/thread_task_runner_handle.h" | |
13 #include "base/time/time.h" | |
14 #include "components/base32/base32.h" | |
15 #include "crypto/sha2.h" | |
16 #include "net/base/net_errors.h" | |
17 #include "net/cert/merkle_audit_proof.h" | |
18 #include "net/dns/dns_client.h" | |
19 #include "net/dns/dns_protocol.h" | |
20 #include "net/dns/dns_response.h" | |
21 #include "net/dns/dns_transaction.h" | |
22 #include "net/dns/record_parsed.h" | |
23 #include "net/dns/record_rdata.h" | |
24 | |
25 namespace certificate_transparency { | |
26 | |
27 LogDnsClient::LogDnsClient(std::unique_ptr<net::DnsClient> dns_client, | |
28 const net::BoundNetLog& net_log) | |
29 : dns_client_(std::move(dns_client)), | |
30 net_log_(net_log), | |
31 weak_ptr_factory_(this) { | |
32 CHECK(dns_client_); | |
33 } | |
34 | |
35 LogDnsClient::~LogDnsClient() {} | |
36 | |
37 void LogDnsClient::QueryLeafIndex(base::StringPiece domain_for_log, | |
38 base::StringPiece leaf_hash, | |
39 const LeafIndexCallback& callback) { | |
40 if (leaf_hash.size() != crypto::kSHA256Length) { | |
41 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
42 FROM_HERE, base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, 0)); | |
43 return; | |
44 } | |
45 | |
46 std::string encoded_leaf_hash = | |
47 base32::Base32Encode(leaf_hash, base32::Base32EncodePolicy::OMIT_PADDING); | |
48 DCHECK_EQ(encoded_leaf_hash.size(), 52u); | |
49 std::string qname = encoded_leaf_hash; | |
Eran Messeri
2016/07/01 09:08:00
Nit: Use StringPrintf
Rob Percival
2016/07/01 16:01:01
I've opted for using std::ostringstream instead, a
| |
50 qname += ".hash."; | |
51 domain_for_log.AppendToString(&qname); | |
52 qname += "."; | |
53 | |
54 net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); | |
55 if (factory == nullptr) { | |
56 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
57 FROM_HERE, | |
58 base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, 0)); | |
59 return; | |
60 } | |
61 | |
62 net::DnsTransactionFactory::CallbackType transaction_callback = base::Bind( | |
63 &LogDnsClient::QueryLeafIndexComplete, weak_ptr_factory_.GetWeakPtr()); | |
64 | |
65 std::unique_ptr<net::DnsTransaction> dns_transaction = | |
66 factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, | |
67 transaction_callback, net_log_); | |
68 | |
69 dns_transaction->Start(); | |
70 leaf_index_queries_.push_back({std::move(dns_transaction), callback}); | |
71 } | |
72 | |
73 // The performance of this could be improved by sending all of the expected | |
74 // queries up front. Each response can contain a maximum of 7 audit path nodes, | |
75 // so for an audit proof of size 20, it could send 3 queries (for nodes 0-6, | |
76 // 7-13 and 14-19) immediately. Currently, it sends only the first and then, | |
77 // based on the number of nodes received, sends the next query. The complexity | |
78 // of the code would increase though, as it would need to detect gaps in the | |
79 // audit proof caused by the server not responding with the anticipated number | |
80 // of nodes. Ownership of the proof would need to change, as it would be shared | |
81 // between simultaneous DNS transactions. | |
82 void LogDnsClient::QueryAuditProof(base::StringPiece domain_for_log, | |
83 uint64_t leaf_index, | |
84 uint64_t tree_size, | |
85 const AuditProofCallback& callback) { | |
86 if (leaf_index >= tree_size) { | |
87 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
88 FROM_HERE, | |
89 base::Bind(callback, net::Error::ERR_INVALID_ARGUMENT, nullptr)); | |
90 return; | |
91 } | |
92 | |
93 std::unique_ptr<net::ct::MerkleAuditProof> proof( | |
94 new net::ct::MerkleAuditProof); | |
95 proof->leaf_index = leaf_index; | |
96 | |
97 QueryAuditProofNodes(std::move(proof), domain_for_log, leaf_index, tree_size, | |
Eran Messeri
2016/07/01 09:07:59
You could drop the leaf_index from the parameters
Rob Percival
2016/07/01 16:01:01
Done.
| |
98 0, callback); | |
Eran Messeri
2016/07/01 09:08:00
comment what the 0 stands for, e.g.:
0 /* first pr
Rob Percival
2016/07/01 16:01:01
Done.
| |
99 } | |
100 | |
101 void LogDnsClient::QueryLeafIndexComplete(net::DnsTransaction* transaction, | |
102 int net_error, | |
103 const net::DnsResponse* response) { | |
104 auto query_iterator = | |
105 std::find_if(leaf_index_queries_.begin(), leaf_index_queries_.end(), | |
106 [transaction](const Query<LeafIndexCallback>& query) { | |
107 return query.transaction.get() == transaction; | |
108 }); | |
109 if (query_iterator == leaf_index_queries_.end()) { | |
110 NOTREACHED(); | |
111 return; | |
112 } | |
113 const Query<LeafIndexCallback> query = std::move(*query_iterator); | |
114 leaf_index_queries_.erase(query_iterator); | |
115 | |
116 if (net_error != net::OK) { | |
117 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
118 FROM_HERE, base::Bind(query.callback, net_error, 0)); | |
119 return; | |
120 } | |
121 | |
122 if (response == nullptr) { | |
Eran Messeri
2016/07/01 09:08:00
Nit: You could move this condition up (to line 115
Rob Percival
2016/07/01 16:01:01
That could hide a more specific net_error though.
| |
123 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
124 FROM_HERE, base::Bind(query.callback, net::ERR_INVALID_RESPONSE, 0)); | |
125 return; | |
126 } | |
127 | |
128 uint64_t leaf_index; | |
129 if (!ParseLeafIndex(*response, &leaf_index)) { | |
130 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
131 FROM_HERE, | |
132 base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, 0)); | |
133 return; | |
134 } | |
135 | |
136 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
137 FROM_HERE, base::Bind(query.callback, net::OK, leaf_index)); | |
138 } | |
139 | |
140 void LogDnsClient::QueryAuditProofNodes( | |
141 std::unique_ptr<net::ct::MerkleAuditProof> proof, | |
142 base::StringPiece domain_for_log, | |
143 uint64_t leaf_index, | |
144 uint64_t tree_size, | |
145 uint64_t node_index, | |
146 const AuditProofCallback& callback) { | |
147 CHECK(proof); | |
148 CHECK_LT(leaf_index, tree_size); | |
149 CHECK_LT(node_index, | |
150 net::ct::CalculateAuditPathLength(leaf_index, tree_size)); | |
151 | |
152 std::string qname = base::Uint64ToString(node_index); | |
153 qname += "."; | |
154 qname += base::Uint64ToString(leaf_index); | |
155 qname += "."; | |
156 qname += base::Uint64ToString(tree_size); | |
157 qname += ".tree."; | |
158 domain_for_log.AppendToString(&qname); | |
159 qname += "."; | |
160 | |
161 net::DnsTransactionFactory* factory = dns_client_->GetTransactionFactory(); | |
162 if (factory == nullptr) { | |
163 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
164 FROM_HERE, | |
165 base::Bind(callback, net::Error::ERR_NAME_RESOLUTION_FAILED, nullptr)); | |
166 return; | |
167 } | |
168 | |
169 net::DnsTransactionFactory::CallbackType transaction_callback = | |
170 base::Bind(&LogDnsClient::QueryAuditProofNodesComplete, | |
171 weak_ptr_factory_.GetWeakPtr(), base::Passed(std::move(proof)), | |
172 domain_for_log, leaf_index, tree_size); | |
173 | |
174 std::unique_ptr<net::DnsTransaction> dns_transaction = | |
175 factory->CreateTransaction(qname, net::dns_protocol::kTypeTXT, | |
176 transaction_callback, net_log_); | |
177 dns_transaction->Start(); | |
178 audit_proof_queries_.push_back({std::move(dns_transaction), callback}); | |
179 } | |
180 | |
181 void LogDnsClient::QueryAuditProofNodesComplete( | |
182 std::unique_ptr<net::ct::MerkleAuditProof> proof, | |
183 base::StringPiece domain_for_log, | |
184 uint64_t leaf_index, | |
185 uint64_t tree_size, | |
186 net::DnsTransaction* transaction, | |
187 int net_error, | |
188 const net::DnsResponse* response) { | |
189 auto query_iterator = | |
190 std::find_if(audit_proof_queries_.begin(), audit_proof_queries_.end(), | |
191 [transaction](const Query<AuditProofCallback>& query) { | |
192 return query.transaction.get() == transaction; | |
193 }); | |
194 | |
195 if (query_iterator == audit_proof_queries_.end()) { | |
196 NOTREACHED(); | |
197 return; | |
198 } | |
199 const Query<AuditProofCallback> query = std::move(*query_iterator); | |
200 audit_proof_queries_.erase(query_iterator); | |
201 | |
202 if (net_error != net::OK) { | |
203 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
204 FROM_HERE, base::Bind(query.callback, net_error, nullptr)); | |
205 return; | |
206 } | |
207 | |
208 if (response == nullptr) { | |
209 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
210 FROM_HERE, | |
211 base::Bind(query.callback, net::ERR_INVALID_RESPONSE, nullptr)); | |
212 return; | |
213 } | |
214 | |
215 const uint64_t audit_path_length = | |
216 net::ct::CalculateAuditPathLength(leaf_index, tree_size); | |
217 proof->nodes.reserve(audit_path_length); | |
218 | |
219 if (!ParseAuditPath(*response, proof.get())) { | |
220 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
221 FROM_HERE, | |
222 base::Bind(query.callback, net::ERR_DNS_MALFORMED_RESPONSE, nullptr)); | |
223 return; | |
224 } | |
225 | |
226 const uint64_t audit_path_nodes_received = proof->nodes.size(); | |
227 if (audit_path_nodes_received < audit_path_length) { | |
228 QueryAuditProofNodes(std::move(proof), domain_for_log, leaf_index, | |
229 tree_size, audit_path_nodes_received, query.callback); | |
230 return; | |
231 } | |
232 | |
233 base::ThreadTaskRunnerHandle::Get()->PostTask( | |
234 FROM_HERE, | |
235 base::Bind(query.callback, net::OK, base::Passed(std::move(proof)))); | |
236 } | |
237 | |
238 bool LogDnsClient::ParseTxtResponse(const net::DnsResponse& response, | |
239 std::string* txt) { | |
240 DCHECK(txt); | |
241 | |
242 net::DnsRecordParser parser = response.Parser(); | |
243 // We don't care about the creation time, since we're going to throw | |
244 // |parsed_record| away as soon as we've extracted the payload, so provide | |
245 // the "null" time. | |
246 auto parsed_record = net::RecordParsed::CreateFrom(&parser, base::Time()); | |
247 if (parsed_record == nullptr) | |
248 return false; | |
249 | |
250 auto txt_record = parsed_record->rdata<net::TxtRecordRdata>(); | |
251 if (txt_record == nullptr) | |
252 return false; | |
253 | |
254 *txt = base::JoinString(txt_record->texts(), ""); | |
255 return true; | |
256 } | |
257 | |
258 bool LogDnsClient::ParseLeafIndex(const net::DnsResponse& response, | |
259 uint64_t* index) { | |
260 std::string index_str; | |
261 if (!ParseTxtResponse(response, &index_str)) | |
262 return false; | |
263 | |
264 return base::StringToUint64(index_str, index); | |
265 } | |
266 | |
267 bool LogDnsClient::ParseAuditPath(const net::DnsResponse& response, | |
268 net::ct::MerkleAuditProof* proof) { | |
269 std::string audit_path; | |
270 if (!ParseTxtResponse(response, &audit_path)) | |
Eran Messeri
2016/07/01 09:07:59
Also return false if the audit_path is of length 0
Rob Percival
2016/07/01 16:01:01
Done.
| |
271 return false; | |
272 | |
273 for (size_t i = 0; i < audit_path.size(); i += crypto::kSHA256Length) { | |
274 std::string node = audit_path.substr(i, crypto::kSHA256Length); | |
275 | |
276 if (node.size() == crypto::kSHA256Length) | |
Eran Messeri
2016/07/01 09:08:00
You could check beforehand:
if (audit_path.size()
Rob Percival
2016/07/01 16:01:01
Done.
| |
277 proof->nodes.push_back(node); | |
278 else | |
279 return false; | |
280 } | |
281 | |
282 return true; | |
283 } | |
284 | |
285 } // namespace certificate_transparency | |
OLD | NEW |