Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(252)

Side by Side Diff: net/dns/mdns_listener_impl.cc

Issue 15733008: Multicast DNS implementation (initial) (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@mdns_implementation2
Patch Set: Created 7 years, 7 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "net/dns/mdns_listener_impl.h"
6
7 #include "base/bind.h"
8 #include "base/time/default_clock.h"
9 #include "base/message_loop_proxy.h"
10 #include "net/base/dns_util.h"
11 #include "net/base/net_errors.h"
12 #include "net/base/net_log.h"
13 #include "net/base/rand_callback.h"
14 #include "net/dns/dns_protocol.h"
15 #include "net/dns/mdns_query.h"
16 #include "net/udp/datagram_socket.h"
17
18 namespace net {
19
20 static const unsigned kMDnsTransactionTimeoutSeconds = 3;
21
22 MDnsListenerFactoryImpl::Core::Core(MDnsListenerFactoryImpl* factory)
23 : socket_ipv4_(new UDPSocket(DatagramSocket::DEFAULT_BIND,
24 RandIntCallback(),
25 NULL, NetLog::Source())),
26 socket_ipv6_(new UDPSocket(DatagramSocket::DEFAULT_BIND,
27 RandIntCallback(),
28 NULL, NetLog::Source())),
29 response_ipv4_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
30 response_ipv6_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
31 factory_(factory) {
32 }
33
34 MDnsListenerFactoryImpl::Core::~Core() {
35 cleanup_callback_.Cancel();
36 socket_ipv4_->Close();
37 socket_ipv6_->Close();
38 }
39
40 bool MDnsListenerFactoryImpl::Core::Init() {
41 if (!BindSocket(socket_ipv4_.get(), kIPv4AddressSize,
42 kMDNSMulticastGroupIPv4))
43 return false;
44
45 if (!BindSocket(socket_ipv6_.get(), kIPv6AddressSize,
46 kMDNSMulticastGroupIPv6))
47 return false;
48
49 if (!RecieveOnePacket(socket_ipv4_.get(),
szym 2013/05/24 15:32:22 What is the point of doing this?
Noam Samuel 2013/05/24 19:00:49 Oh. I guess this method is somewhat misnamed. Reci
Noam Samuel 2013/05/29 21:25:16 Renamed method and moved it to MDnsConnectionImpl.
50 response_ipv4_.get(),
51 &recv_addr_ipv4_))
52 return false;
53
54 if (!RecieveOnePacket(socket_ipv6_.get(),
55 response_ipv6_.get(),
56 &recv_addr_ipv6_))
57 return false;
58
59 return true;
60 }
61
62 bool MDnsListenerFactoryImpl::Core::SendQuery(uint16 rrtype, std::string name) {
63 std::string name_dns;
64 if (!DNSDomainFromDot(name, &name_dns))
65 return false;
66
67 MDnsQuery query(name_dns, rrtype);
68
69 int rv = socket_ipv4_->SendTo(
70 query.io_buffer(),
71 query.size(),
72 factory_->GetIPv4SendEndpoint(),
73 base::Bind(&MDnsListenerFactoryImpl::Core::SendDone,
74 AsWeakPtr()));
75 if (rv < OK && rv != ERR_IO_PENDING) return false;
76
77 rv = socket_ipv6_->SendTo(
78 query.io_buffer(),
79 query.size(),
80 factory_->GetIPv6SendEndpoint(),
81 base::Bind(&MDnsListenerFactoryImpl::Core::SendDone,
82 AsWeakPtr()));
83 if (rv < OK && rv != ERR_IO_PENDING) return false;
84
85 return true;
86 }
87
88 void MDnsListenerFactoryImpl::Core::SendDone(int sent) {
89 // TODO(noamsml): Queueing and retry logic
90 }
91
92 bool MDnsListenerFactoryImpl::Core::BindSocket(
93 UDPSocket* socket,
94 int addr_size,
95 const char* multicast_group) {
96 IPAddressNumber address_any;
97 address_any.resize(addr_size, 0);
98
99 IPAddressNumber multicast_group_number;
100
101 IPEndPoint bind_endpoint(address_any, dns_protocol::kDefaultPortMulticast);
102
103 bool success = ParseIPLiteralToNumber(multicast_group,
104 &multicast_group_number);
105 DCHECK(success);
106
107 return factory_->BindToAddressAndMulticastGroup(
108 socket, bind_endpoint, multicast_group_number);
109 }
110
111 bool MDnsListenerFactoryImpl::Core::RecieveOnePacket(
112 UDPSocket* socket,
113 DnsResponse* response,
114 IPEndPoint* recv_addr) {
115 int rval = socket->RecvFrom(
116 response->io_buffer(),
117 response->io_buffer()->size(),
118 recv_addr,
119 base::Bind(&MDnsListenerFactoryImpl::Core::OnDatagramRecieved,
120 AsWeakPtr(), socket, response, recv_addr));
szym 2013/05/25 02:51:01 You don't really need WeakPtr here since if Core i
Noam Samuel 2013/05/29 21:25:16 Done.
121
122 if (rval > 0) {
123 factory_->task_runner_->PostTask(
124 FROM_HERE,
125 base::Bind(&MDnsListenerFactoryImpl::Core::OnDatagramRecieved,
szym 2013/05/25 02:51:01 Why not call synchronously?
Noam Samuel 2013/05/29 17:53:03 If an attacker were to spam the mdns port with pac
szym 2013/05/29 17:59:41 The solution to this is to replace recursion with
Noam Samuel 2013/05/29 21:25:16 Done.
126 AsWeakPtr(),
127 socket,
128 response,
129 recv_addr,
130 rval));
131 } else {
132 if (rval != ERR_IO_PENDING) return false;
133 }
134
135 return true;
136 }
137
138 void MDnsListenerFactoryImpl::Core::OnDatagramRecieved(
139 UDPSocket* socket,
140 DnsResponse* response,
141 IPEndPoint* recv_addr,
142 int bytes_read) {
143 // TODO(noamsml): More sophisticated error handling.
144 DCHECK_GT(bytes_read, 0);
145 HandlePacket(response, bytes_read);
146 bool success = RecieveOnePacket(socket, response, recv_addr);
147
148 DCHECK(success); // TODO(noamsml): exponential backoff.
149 }
150
151 void MDnsListenerFactoryImpl::Core::HandlePacket(DnsResponse* response,
152 int bytes_read) {
153 unsigned offset;
154
155 if (!response->InitParseWithoutQuery(bytes_read)) {
156 LOG(WARNING) << "Could not understand an mDNS packet.";
157 return; // Message is unreadable.
158 }
159
160 // TODO(noamsml): duplicate query suppression.
161 if (!(response->flags() & dns_protocol::kFlagResponse)) {
162 return; // Message is a query. ignore it.
163 }
164 DnsRecordParser parser = response->Parser();
165 unsigned answer_count = response->answer_count() +
166 response->additional_answer_count();
167
168 for (unsigned i = 0; i < answer_count; i++) {
169 offset = parser.GetOffset();
170 scoped_ptr<const RecordParsed> scoped_record = RecordParsed::CreateFrom(
171 &parser, factory_->clock_->Now());
172
173 if (!scoped_record) {
174 LOG(WARNING) << "Could not understand an mDNS record.";
175
176 if (offset == parser.GetOffset()) {
177 LOG(WARNING) << "Abandoned parsing the rest of the packet.";
178 return; // The parser did not advance, abort reading the packet.
179 } else {
180 continue; // We may be able to extract other records from the packet.
181 }
182 }
183
184 if ((scoped_record->klass() & dns_protocol::kMDnsClassMask) !=
185 dns_protocol::kClassIN) {
186 LOG(WARNING) << "Recieved an mDNS record with non-IN class. Ignoring.";
187 continue; // Ignore all records not in the IN class.
188 }
189
190 // We want to retain a copy of the record pointer for updating listeners
191 // but we are passing ownership to the cache.
192 const RecordParsed* record = scoped_record.get();
193 MDnsCache::UpdateType update = cache_.UpdateDnsRecord(scoped_record.Pass());
194
195 // Cleanup time may have changed.
196 ScheduleCleanup(cache_.next_expiration());
197
198 if (update != MDnsCache::NoChange) {
199 MDnsUpdateType update_external;
200
201 switch (update) {
202 case MDnsCache::RecordAdded:
203 update_external = kMDnsRecordAdded;
204 break;
205 case MDnsCache::RecordChanged:
206 update_external = kMDnsRecordChanged;
207 break;
208 case MDnsCache::NoChange:
209 NOTREACHED();
210 // Dummy assignment to suppress compiler warning.
211 update_external = kMDnsRecordChanged;
212 break;
213 }
214
215 AlertListeners(update_external,
216 ListenerKey(record->type(), record->name()), record);
217 // Alert listeners listening only for rrtype and not for name.
218 AlertListeners(update_external, ListenerKey(record->type(), ""), record);
219 }
220 }
221 }
222
223 void MDnsListenerFactoryImpl::Core::AlertListeners(
224 MDnsUpdateType update_type,
225 const ListenerKey& key,
226 const RecordParsed* record) {
227 std::pair<ListenersIterator, ListenersIterator> listeners =
228 listeners_.equal_range(key);
229 for (ListenersIterator l = listeners.first; l != listeners.second; ++l) {
230 scoped_ptr<const RecordParsed> record_clone = record->Clone();
231 factory_->task_runner_->PostTask(FROM_HERE, base::Bind(
232 &MDnsListenerImpl::AlertDelegate, l->second->AsWeakPtr(), update_type,
233 base::Owned(record_clone.release())) );
szym 2013/05/24 15:32:22 Why not just call: l->second->AlertDelegate(update
Noam Samuel 2013/05/24 19:00:49 Listener delegates may create or delete any listen
Noam Samuel 2013/05/24 22:34:20 On second thought, deferring shutting down listeni
szym 2013/05/25 02:51:01 You could consider using std::map<ListenerKey, Obs
Noam Samuel 2013/05/29 17:53:03 Taking a scoped reference in the class entrypoints
szym 2013/05/29 17:59:41 Ok, but now you have to deal with a situation wher
Noam Samuel 2013/05/29 18:35:35 Each individual listener is added/removed in an RA
szym 2013/05/29 18:40:43 With deferred deletion, how do you avoid calling b
Noam Samuel 2013/05/29 20:14:49 Oh, sorry. The deferred deletion applies only to t
234 }
235 }
236
237 void MDnsListenerFactoryImpl::Core::AddListener(
238 MDnsListenerImpl* listener, bool alert_existing_records) {
239 ListenerKey key(listener->GetType(), listener->GetName());
240
241 listeners_.insert(std::pair<ListenerKey, MDnsListenerImpl*>(key, listener));
242
243 if (alert_existing_records) {
244 std::vector<const RecordParsed*> records;
245
246 cache_.FindDnsRecords(listener->GetType(), listener->GetName(),
247 &records, factory_->clock_->Now());
248
249 for (std::vector<const RecordParsed*>::iterator i = records.begin();
250 i != records.end(); i++) {
251 scoped_ptr<const RecordParsed> record_clone = (*i)->Clone();
252 factory_->task_runner_->PostTask(FROM_HERE, base::Bind(
253 &MDnsListenerImpl::AlertDelegate, listener->AsWeakPtr(),
254 kMDnsRecordAdded, base::Owned(record_clone.release())) );
255 }
256 }
257 }
258
259 void MDnsListenerFactoryImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
260 ListenerKey key(listener->GetType(), listener->GetName());
261
262 bool removed = false;
263
264 std::pair<ListenersIterator, ListenersIterator> listeners =
265 listeners_.equal_range(key);
266 for (ListenersIterator l = listeners.first; l != listeners.second;
267 l++) {
268 if (l->second == listener) {
269 listeners_.erase(l);
270 removed = true;
271 break;
272 }
273 }
274 DCHECK(removed);
275
276 // When we remove a listener, we notify factory that a listen reference has
277 // been removed. This may cause the core to be deleted.
278 factory_->SubtractListenRef();
279 }
280
281 void MDnsListenerFactoryImpl::Core::ScheduleCleanup(base::Time cleanup) {
282 // Cleanup is already scheduled, no need to do anything.
283 if (cleanup == scheduled_cleanup_) return;
284 scheduled_cleanup_ = cleanup;
285
286 // This line has the effect of cancelling the previously scheduled cleanup.
287 cleanup_callback_.Reset(base::Bind(
288 &MDnsListenerFactoryImpl::Core::DoCleanup, base::Unretained(this)));
289
290 // cleanup == base::Time means no cleanup necessary.
291 if (cleanup != base::Time()) {
292 factory_->task_runner_->PostDelayedTask(
293 FROM_HERE,
294 cleanup_callback_.callback(),
295 cleanup - factory_->clock_->Now());
296 }
297 }
298
299 void MDnsListenerFactoryImpl::Core::DoCleanup() {
300 cache_.CleanupRecords(factory_->clock_->Now(), base::Bind(
301 &MDnsListenerFactoryImpl::Core::OnRecordRemoved, base::Unretained(this)));
302
303 ScheduleCleanup(cache_.next_expiration());
304 }
305
306 void MDnsListenerFactoryImpl::Core::OnRecordRemoved(
307 const RecordParsed* record) {
308 AlertListeners(kMDnsRecordRemoved,
309 ListenerKey(record->type(), record->name()), record);
310 // Alert listeners listening only for rrtype and not for name.
311 AlertListeners(kMDnsRecordRemoved, ListenerKey(record->type(), ""),
312 record);
313 }
314
315 void MDnsListenerFactoryImpl::Core::QueryCache(
316 uint16 rrtype, const std::string& name,
317 std::vector<const RecordParsed*>* records) const {
318 cache_.FindDnsRecords(rrtype, name, records, factory_->clock_->Now());
319 }
320
321 MDnsListenerFactoryImpl::MDnsListenerFactoryImpl()
322 : listen_refs_(0), clock_owned_(new base::DefaultClock()),
323 task_runner_(base::MessageLoopProxy::current()) {
324 clock_ = clock_owned_.get();
325 }
326
327 MDnsListenerFactoryImpl::MDnsListenerFactoryImpl(base::Clock* clock,
328 base::TaskRunner* task_runner)
329 : listen_refs_(0), clock_(clock), task_runner_(task_runner) {
330 }
331
332 MDnsListenerFactoryImpl::~MDnsListenerFactoryImpl() {
333 }
334
335 bool MDnsListenerFactoryImpl::AddListenRef() {
336 if (listen_refs_ == 0) {
337 core_.reset(new Core(this));
338 if (!core_->Init()) {
339 core_.reset();
340 return false;
341 }
342 }
343 listen_refs_++;
344 return true;
345 }
346
347 void MDnsListenerFactoryImpl::SubtractListenRef() {
348 listen_refs_--;
349 if (listen_refs_ == 0) core_.reset(NULL);
350 }
351
352 bool MDnsListenerFactoryImpl::BindToAddressAndMulticastGroup(
353 UDPSocket* socket,
354 const IPEndPoint& address_bind,
355 const IPAddressNumber& group_multicast) {
356 socket->AllowAddressReuse();
357 int status = socket->Bind(address_bind);
358
359 if (status < 0)
360 return false;
361
362 socket->SetMulticastLoopbackMode(false);
363
364 status = socket->JoinGroup(group_multicast);
365
366 if (status < 0)
367 return false;
368
369 return true;
370 }
371
372 bool MDnsListenerFactoryImpl::IsListeningForTests() {
373 return core_.get() != NULL;
374 }
375
376 IPEndPoint MDnsListenerFactoryImpl::GetIPv4SendEndpoint() {
377 IPAddressNumber multicast_group_number;
378 bool success = ParseIPLiteralToNumber(kMDNSMulticastGroupIPv4,
379 &multicast_group_number);
380 DCHECK(success);
381 return IPEndPoint(multicast_group_number,
382 dns_protocol::kDefaultPortMulticast);
383 }
384
385 IPEndPoint MDnsListenerFactoryImpl::GetIPv6SendEndpoint() {
386 IPAddressNumber multicast_group_number;
387 bool success = ParseIPLiteralToNumber(kMDNSMulticastGroupIPv6,
388 &multicast_group_number);
389 DCHECK(success);
390 return IPEndPoint(multicast_group_number,
391 dns_protocol::kDefaultPortMulticast);
392 }
393
394 scoped_ptr<MDnsListener> MDnsListenerFactoryImpl::CreateListener(
395 uint16 rrtype,
396 const std::string& name,
397 bool active,
398 bool alert_existing_records,
399 MDnsListenerFactory::Delegate* delegate) {
400 if (!AddListenRef()) return scoped_ptr<net::MDnsListener>();
401
402 return scoped_ptr<net::MDnsListener>(
403 new MDnsListenerImpl(rrtype, name, active, alert_existing_records,
404 delegate, core_.get()));
405 }
406
407 scoped_ptr<MDnsTransaction> MDnsListenerFactoryImpl::CreateTransaction(
408 uint16 rrtype,
409 const std::string& name,
410 const MDnsListenerFactory::QueryCallback& callback) {
411 scoped_ptr<MDnsTransactionImpl> transaction(
412 new MDnsTransactionImpl(rrtype, name, callback, task_runner_));
413
414 if (transaction->Init(this, core_.get())) {
415 return scoped_ptr<MDnsTransaction>(transaction.Pass());
416 } else {
417 return scoped_ptr<MDnsTransaction>();
418 }
419 }
420
421 MDnsListenerImpl::MDnsListenerImpl(
422 uint16 rrtype,
423 const std::string& name,
424 bool active,
425 bool alert_existing_records,
426 MDnsListenerFactory::Delegate* delegate,
427 MDnsListenerFactoryImpl::Core* core)
428 : rrtype_(rrtype), name_(name), active_(active),
429 parent_(core), delegate_(delegate) {
430 parent_->AddListener(this, alert_existing_records);
431
432 if (active) SendQuery(false); // TODO(noamsml): Retry logic.
433 }
434
435 MDnsListenerImpl::~MDnsListenerImpl() {
436 parent_->RemoveListener(this);
437 }
438
439 const std::string& MDnsListenerImpl::GetName() const {
440 return name_;
441 }
442
443 uint16 MDnsListenerImpl::GetType() const {
444 return rrtype_;
445 }
446
447 bool MDnsListenerImpl::IsActive() const {
448 return active_;
449 }
450
451 bool MDnsListenerImpl::SendQuery(bool force_refresh_cache) {
452 // TODO(noamsml): Logic related to force_refresh_cache
453 if (name_.size() == 0) return false;
454 return parent_->SendQuery(rrtype_, name_);
455 }
456
457 bool MDnsListenerImpl::QueryCache(
458 std::vector<const RecordParsed*>* records) const {
459 if (name_.size() == 0) return false;
460 parent_->QueryCache(rrtype_, name_, records);
461 return true;
462 }
463
464 void MDnsListenerImpl::AlertDelegate(MDnsUpdateType update_type,
465 const RecordParsed* record) {
466 delegate_->OnRecordUpdate(update_type, record);
467 }
468
469 MDnsTransactionImpl::MDnsTransactionImpl(
470 uint16 rrtype,
471 const std::string& name,
472 const MDnsListenerFactory::QueryCallback& callback,
473 base::TaskRunner* task_runner)
474 : rrtype_(rrtype), name_(name), callback_(callback), triggered_(false),
475 task_runner_(task_runner) {
476 }
477
478 MDnsTransactionImpl::~MDnsTransactionImpl() {
479 }
480
481 bool MDnsTransactionImpl::Init(
482 MDnsListenerFactoryImpl* factory,
483 MDnsListenerFactoryImpl::Core* core) {
484 DCHECK(factory);
485 std::vector<const RecordParsed*> records;
486 if (core) {
487 core->QueryCache(rrtype_, name_, &records);
488 if (!records.empty()) {
489 scoped_ptr<const RecordParsed> record_clone = records.front()->Clone();
490 task_runner_->PostTask(
491 FROM_HERE,
492 base::Bind(&MDnsTransactionImpl::CacheRecordFound,
493 AsWeakPtr(), base::Owned(
494 record_clone.release())) );
495
496 return true;
497 }
498 }
499
500 listener_ = factory->CreateListener(rrtype_, name_, true /*active*/,
501 false /*alert existing*/, this);
502
503 timeout_.Reset(base::Bind(&MDnsTransactionImpl::OnTimedOut, AsWeakPtr()));
504 task_runner_->PostDelayedTask(
505 FROM_HERE,
506 timeout_.callback(),
507 base::TimeDelta::FromSeconds(kMDnsTransactionTimeoutSeconds));
508
509 return listener_.get() != NULL;
510 }
511
512 const std::string& MDnsTransactionImpl::GetName() const {
513 return name_;
514 }
515
516 uint16 MDnsTransactionImpl::GetType() const {
517 return rrtype_;
518 }
519
520 void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
521 OnRecordUpdate(kMDnsRecordAdded, record);
522 }
523
524 void MDnsTransactionImpl::TriggerCallback(MDnsTransactionResult result,
525 const RecordParsed* record) {
526 if (triggered_) return;
527 triggered_ = true;
528
529 // Ensure callback is run after touching all class state, so that
530 // the callback can delete the transaction.
531 MDnsListenerFactory::QueryCallback callback = callback_;
532
533 callback_.Reset();
534 listener_.reset();
535 timeout_.Cancel();
536
537 callback.Run(result, record);
538 }
539
540 void MDnsTransactionImpl::OnRecordUpdate(MDnsUpdateType update,
541 const RecordParsed* record) {
542 if (update == kMDnsRecordAdded || update == kMDnsRecordChanged) {
543 TriggerCallback(kMDnsTransactionSuccess, record);
544 }
545 }
546
547 void MDnsTransactionImpl::OnTimedOut() {
548 TriggerCallback(kMDnsTransactionTimeout, NULL);
549 }
550
551 void MDnsTransactionImpl::OnNSecRecord(const std::string& name, unsigned type) {
552 // TODO(noamsml): NSEC records not yet implemented
553 }
554
555 } // namespace net
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698