OLD | NEW |
1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/dns/mdns_client_impl.h" | 5 #include "net/dns/mdns_client_impl.h" |
6 | 6 |
| 7 #include <algorithm> |
7 #include <queue> | 8 #include <queue> |
8 | 9 |
9 #include "base/bind.h" | 10 #include "base/bind.h" |
10 #include "base/message_loop/message_loop_proxy.h" | 11 #include "base/message_loop/message_loop_proxy.h" |
11 #include "base/stl_util.h" | 12 #include "base/stl_util.h" |
| 13 #include "base/time/clock.h" |
12 #include "base/time/default_clock.h" | 14 #include "base/time/default_clock.h" |
13 #include "base/time/time.h" | 15 #include "base/time/time.h" |
| 16 #include "base/timer/timer.h" |
14 #include "net/base/dns_util.h" | 17 #include "net/base/dns_util.h" |
15 #include "net/base/net_errors.h" | 18 #include "net/base/net_errors.h" |
16 #include "net/base/net_log.h" | 19 #include "net/base/net_log.h" |
17 #include "net/base/rand_callback.h" | 20 #include "net/base/rand_callback.h" |
18 #include "net/dns/dns_protocol.h" | 21 #include "net/dns/dns_protocol.h" |
19 #include "net/dns/record_rdata.h" | 22 #include "net/dns/record_rdata.h" |
20 #include "net/udp/datagram_socket.h" | 23 #include "net/udp/datagram_socket.h" |
21 | 24 |
22 // TODO(gene): Remove this temporary method of disabling NSEC support once it | 25 // TODO(gene): Remove this temporary method of disabling NSEC support once it |
23 // becomes clear whether this feature should be | 26 // becomes clear whether this feature should be |
(...skipping 49 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
73 multicast_addr_ = GetMDnsIPEndPoint(end_point.GetFamily()); | 76 multicast_addr_ = GetMDnsIPEndPoint(end_point.GetFamily()); |
74 return DoLoop(0); | 77 return DoLoop(0); |
75 } | 78 } |
76 | 79 |
77 int MDnsConnection::SocketHandler::DoLoop(int rv) { | 80 int MDnsConnection::SocketHandler::DoLoop(int rv) { |
78 do { | 81 do { |
79 if (rv > 0) | 82 if (rv > 0) |
80 connection_->OnDatagramReceived(&response_, recv_addr_, rv); | 83 connection_->OnDatagramReceived(&response_, recv_addr_, rv); |
81 | 84 |
82 rv = socket_->RecvFrom( | 85 rv = socket_->RecvFrom( |
83 response_.io_buffer(), | 86 response_.io_buffer(), response_.io_buffer()->size(), &recv_addr_, |
84 response_.io_buffer()->size(), | |
85 &recv_addr_, | |
86 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived, | 87 base::Bind(&MDnsConnection::SocketHandler::OnDatagramReceived, |
87 base::Unretained(this))); | 88 base::Unretained(this))); |
88 } while (rv > 0); | 89 } while (rv > 0); |
89 | 90 |
90 if (rv != ERR_IO_PENDING) | 91 if (rv != ERR_IO_PENDING) |
91 return rv; | 92 return rv; |
92 | 93 |
93 return OK; | 94 return OK; |
94 } | 95 } |
95 | 96 |
(...skipping 92 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
188 | 189 |
189 void MDnsConnection::OnDatagramReceived( | 190 void MDnsConnection::OnDatagramReceived( |
190 DnsResponse* response, | 191 DnsResponse* response, |
191 const IPEndPoint& recv_addr, | 192 const IPEndPoint& recv_addr, |
192 int bytes_read) { | 193 int bytes_read) { |
193 // TODO(noamsml): More sophisticated error handling. | 194 // TODO(noamsml): More sophisticated error handling. |
194 DCHECK_GT(bytes_read, 0); | 195 DCHECK_GT(bytes_read, 0); |
195 delegate_->HandlePacket(response, bytes_read); | 196 delegate_->HandlePacket(response, bytes_read); |
196 } | 197 } |
197 | 198 |
198 MDnsClientImpl::Core::Core() : connection_(new MDnsConnection(this)) { | 199 MDnsClientImpl::Core::Core(base::Clock* clock, base::Timer* timer) |
| 200 : clock_(clock), |
| 201 cleanup_timer_(timer), |
| 202 connection_(new MDnsConnection(this)) { |
199 } | 203 } |
200 | 204 |
201 MDnsClientImpl::Core::~Core() { | 205 MDnsClientImpl::Core::~Core() { |
202 STLDeleteValues(&listeners_); | 206 STLDeleteValues(&listeners_); |
203 } | 207 } |
204 | 208 |
205 bool MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) { | 209 bool MDnsClientImpl::Core::Init(MDnsSocketFactory* socket_factory) { |
206 return connection_->Init(socket_factory); | 210 return connection_->Init(socket_factory); |
207 } | 211 } |
208 | 212 |
(...skipping 25 matching lines...) Expand all Loading... |
234 // TODO(noamsml): duplicate query suppression. | 238 // TODO(noamsml): duplicate query suppression. |
235 if (!(response->flags() & dns_protocol::kFlagResponse)) | 239 if (!(response->flags() & dns_protocol::kFlagResponse)) |
236 return; // Message is a query. ignore it. | 240 return; // Message is a query. ignore it. |
237 | 241 |
238 DnsRecordParser parser = response->Parser(); | 242 DnsRecordParser parser = response->Parser(); |
239 unsigned answer_count = response->answer_count() + | 243 unsigned answer_count = response->answer_count() + |
240 response->additional_answer_count(); | 244 response->additional_answer_count(); |
241 | 245 |
242 for (unsigned i = 0; i < answer_count; i++) { | 246 for (unsigned i = 0; i < answer_count; i++) { |
243 offset = parser.GetOffset(); | 247 offset = parser.GetOffset(); |
244 scoped_ptr<const RecordParsed> record = RecordParsed::CreateFrom( | 248 scoped_ptr<const RecordParsed> record = |
245 &parser, base::Time::Now()); | 249 RecordParsed::CreateFrom(&parser, clock_->Now()); |
246 | 250 |
247 if (!record) { | 251 if (!record) { |
248 DVLOG(1) << "Could not understand an mDNS record."; | 252 DVLOG(1) << "Could not understand an mDNS record."; |
249 | 253 |
250 if (offset == parser.GetOffset()) { | 254 if (offset == parser.GetOffset()) { |
251 DVLOG(1) << "Abandoned parsing the rest of the packet."; | 255 DVLOG(1) << "Abandoned parsing the rest of the packet."; |
252 return; // The parser did not advance, abort reading the packet. | 256 return; // The parser did not advance, abort reading the packet. |
253 } else { | 257 } else { |
254 continue; // We may be able to extract other records from the packet. | 258 continue; // We may be able to extract other records from the packet. |
255 } | 259 } |
(...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
288 } | 292 } |
289 | 293 |
290 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { | 294 void MDnsClientImpl::Core::NotifyNsecRecord(const RecordParsed* record) { |
291 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type()); | 295 DCHECK_EQ(dns_protocol::kTypeNSEC, record->type()); |
292 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>(); | 296 const NsecRecordRdata* rdata = record->rdata<NsecRecordRdata>(); |
293 DCHECK(rdata); | 297 DCHECK(rdata); |
294 | 298 |
295 // Remove all cached records matching the nonexistent RR types. | 299 // Remove all cached records matching the nonexistent RR types. |
296 std::vector<const RecordParsed*> records_to_remove; | 300 std::vector<const RecordParsed*> records_to_remove; |
297 | 301 |
298 cache_.FindDnsRecords(0, record->name(), &records_to_remove, | 302 cache_.FindDnsRecords(0, record->name(), &records_to_remove, clock_->Now()); |
299 base::Time::Now()); | |
300 | 303 |
301 for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin(); | 304 for (std::vector<const RecordParsed*>::iterator i = records_to_remove.begin(); |
302 i != records_to_remove.end(); i++) { | 305 i != records_to_remove.end(); i++) { |
303 if ((*i)->type() == dns_protocol::kTypeNSEC) | 306 if ((*i)->type() == dns_protocol::kTypeNSEC) |
304 continue; | 307 continue; |
305 if (!rdata->GetBit((*i)->type())) { | 308 if (!rdata->GetBit((*i)->type())) { |
306 scoped_ptr<const RecordParsed> record_removed = cache_.RemoveRecord((*i)); | 309 scoped_ptr<const RecordParsed> record_removed = cache_.RemoveRecord((*i)); |
307 DCHECK(record_removed); | 310 DCHECK(record_removed); |
308 OnRecordRemoved(record_removed.get()); | 311 OnRecordRemoved(record_removed.get()); |
309 } | 312 } |
(...skipping 63 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
373 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { | 376 void MDnsClientImpl::Core::CleanupObserverList(const ListenerKey& key) { |
374 ListenerMap::iterator found = listeners_.find(key); | 377 ListenerMap::iterator found = listeners_.find(key); |
375 if (found != listeners_.end() && !found->second->might_have_observers()) { | 378 if (found != listeners_.end() && !found->second->might_have_observers()) { |
376 delete found->second; | 379 delete found->second; |
377 listeners_.erase(found); | 380 listeners_.erase(found); |
378 } | 381 } |
379 } | 382 } |
380 | 383 |
381 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) { | 384 void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) { |
382 // Cleanup is already scheduled, no need to do anything. | 385 // Cleanup is already scheduled, no need to do anything. |
383 if (cleanup == scheduled_cleanup_) return; | 386 if (cleanup == scheduled_cleanup_) { |
| 387 return; |
| 388 } |
384 scheduled_cleanup_ = cleanup; | 389 scheduled_cleanup_ = cleanup; |
385 | 390 |
386 // This cancels the previously scheduled cleanup. | 391 // This cancels the previously scheduled cleanup. |
387 cleanup_callback_.Reset(base::Bind( | 392 cleanup_timer_->Stop(); |
388 &MDnsClientImpl::Core::DoCleanup, base::Unretained(this))); | |
389 | 393 |
390 // If |cleanup| is empty, then no cleanup necessary. | 394 // If |cleanup| is empty, then no cleanup necessary. |
391 if (cleanup != base::Time()) { | 395 if (cleanup != base::Time()) { |
392 base::MessageLoop::current()->PostDelayedTask( | 396 cleanup_timer_->Start( |
393 FROM_HERE, | 397 FROM_HERE, std::max(base::TimeDelta(), cleanup - clock_->Now()), |
394 cleanup_callback_.callback(), | 398 base::Bind(&MDnsClientImpl::Core::DoCleanup, base::Unretained(this))); |
395 cleanup - base::Time::Now()); | |
396 } | 399 } |
397 } | 400 } |
398 | 401 |
399 void MDnsClientImpl::Core::DoCleanup() { | 402 void MDnsClientImpl::Core::DoCleanup() { |
400 cache_.CleanupRecords(base::Time::Now(), base::Bind( | 403 cache_.CleanupRecords(clock_->Now(), |
401 &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this))); | 404 base::Bind(&MDnsClientImpl::Core::OnRecordRemoved, |
| 405 base::Unretained(this))); |
402 | 406 |
403 ScheduleCleanup(cache_.next_expiration()); | 407 ScheduleCleanup(cache_.next_expiration()); |
404 } | 408 } |
405 | 409 |
406 void MDnsClientImpl::Core::OnRecordRemoved( | 410 void MDnsClientImpl::Core::OnRecordRemoved( |
407 const RecordParsed* record) { | 411 const RecordParsed* record) { |
408 AlertListeners(MDnsCache::RecordRemoved, | 412 AlertListeners(MDnsCache::RecordRemoved, |
409 ListenerKey(record->name(), record->type()), record); | 413 ListenerKey(record->name(), record->type()), record); |
410 } | 414 } |
411 | 415 |
412 void MDnsClientImpl::Core::QueryCache( | 416 void MDnsClientImpl::Core::QueryCache( |
413 uint16 rrtype, const std::string& name, | 417 uint16 rrtype, const std::string& name, |
414 std::vector<const RecordParsed*>* records) const { | 418 std::vector<const RecordParsed*>* records) const { |
415 cache_.FindDnsRecords(rrtype, name, records, base::Time::Now()); | 419 cache_.FindDnsRecords(rrtype, name, records, clock_->Now()); |
416 } | 420 } |
417 | 421 |
418 MDnsClientImpl::MDnsClientImpl() { | 422 MDnsClientImpl::MDnsClientImpl() |
| 423 : clock_(new base::DefaultClock), |
| 424 cleanup_timer_(new base::Timer(false, false)) { |
| 425 } |
| 426 |
| 427 MDnsClientImpl::MDnsClientImpl(scoped_ptr<base::Clock> clock, |
| 428 scoped_ptr<base::Timer> timer) |
| 429 : clock_(clock.Pass()), cleanup_timer_(timer.Pass()) { |
419 } | 430 } |
420 | 431 |
421 MDnsClientImpl::~MDnsClientImpl() { | 432 MDnsClientImpl::~MDnsClientImpl() { |
422 } | 433 } |
423 | 434 |
424 bool MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) { | 435 bool MDnsClientImpl::StartListening(MDnsSocketFactory* socket_factory) { |
425 DCHECK(!core_.get()); | 436 DCHECK(!core_.get()); |
426 core_.reset(new Core()); | 437 core_.reset(new Core(clock_.get(), cleanup_timer_.get())); |
427 if (!core_->Init(socket_factory)) { | 438 if (!core_->Init(socket_factory)) { |
428 core_.reset(); | 439 core_.reset(); |
429 return false; | 440 return false; |
430 } | 441 } |
431 return true; | 442 return true; |
432 } | 443 } |
433 | 444 |
434 void MDnsClientImpl::StopListening() { | 445 void MDnsClientImpl::StopListening() { |
435 core_.reset(); | 446 core_.reset(); |
436 } | 447 } |
437 | 448 |
438 bool MDnsClientImpl::IsListening() const { | 449 bool MDnsClientImpl::IsListening() const { |
439 return core_.get() != NULL; | 450 return core_.get() != NULL; |
440 } | 451 } |
441 | 452 |
442 scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener( | 453 scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener( |
443 uint16 rrtype, | 454 uint16 rrtype, |
444 const std::string& name, | 455 const std::string& name, |
445 MDnsListener::Delegate* delegate) { | 456 MDnsListener::Delegate* delegate) { |
446 return scoped_ptr<net::MDnsListener>( | 457 return scoped_ptr<net::MDnsListener>( |
447 new MDnsListenerImpl(rrtype, name, delegate, this)); | 458 new MDnsListenerImpl(rrtype, name, clock_.get(), delegate, this)); |
448 } | 459 } |
449 | 460 |
450 scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction( | 461 scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction( |
451 uint16 rrtype, | 462 uint16 rrtype, |
452 const std::string& name, | 463 const std::string& name, |
453 int flags, | 464 int flags, |
454 const MDnsTransaction::ResultCallback& callback) { | 465 const MDnsTransaction::ResultCallback& callback) { |
455 return scoped_ptr<MDnsTransaction>( | 466 return scoped_ptr<MDnsTransaction>( |
456 new MDnsTransactionImpl(rrtype, name, flags, callback, this)); | 467 new MDnsTransactionImpl(rrtype, name, flags, callback, this)); |
457 } | 468 } |
458 | 469 |
459 MDnsListenerImpl::MDnsListenerImpl( | 470 MDnsListenerImpl::MDnsListenerImpl(uint16 rrtype, |
460 uint16 rrtype, | 471 const std::string& name, |
461 const std::string& name, | 472 base::Clock* clock, |
462 MDnsListener::Delegate* delegate, | 473 MDnsListener::Delegate* delegate, |
463 MDnsClientImpl* client) | 474 MDnsClientImpl* client) |
464 : rrtype_(rrtype), name_(name), client_(client), delegate_(delegate), | 475 : rrtype_(rrtype), |
465 started_(false), active_refresh_(false) { | 476 name_(name), |
| 477 clock_(clock), |
| 478 client_(client), |
| 479 delegate_(delegate), |
| 480 started_(false), |
| 481 active_refresh_(false) { |
466 } | 482 } |
467 | 483 |
468 MDnsListenerImpl::~MDnsListenerImpl() { | 484 MDnsListenerImpl::~MDnsListenerImpl() { |
469 if (started_) { | 485 if (started_) { |
470 DCHECK(client_->core()); | 486 DCHECK(client_->core()); |
471 client_->core()->RemoveListener(this); | 487 client_->core()->RemoveListener(this); |
472 } | 488 } |
473 } | 489 } |
474 | 490 |
475 bool MDnsListenerImpl::Start() { | 491 bool MDnsListenerImpl::Start() { |
(...skipping 88 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
564 // response being received. | 580 // response being received. |
565 base::Time next_refresh1 = last_update_ + base::TimeDelta::FromMilliseconds( | 581 base::Time next_refresh1 = last_update_ + base::TimeDelta::FromMilliseconds( |
566 static_cast<int>(base::Time::kMillisecondsPerSecond * | 582 static_cast<int>(base::Time::kMillisecondsPerSecond * |
567 kListenerRefreshRatio1 * ttl_)); | 583 kListenerRefreshRatio1 * ttl_)); |
568 | 584 |
569 base::Time next_refresh2 = last_update_ + base::TimeDelta::FromMilliseconds( | 585 base::Time next_refresh2 = last_update_ + base::TimeDelta::FromMilliseconds( |
570 static_cast<int>(base::Time::kMillisecondsPerSecond * | 586 static_cast<int>(base::Time::kMillisecondsPerSecond * |
571 kListenerRefreshRatio2 * ttl_)); | 587 kListenerRefreshRatio2 * ttl_)); |
572 | 588 |
573 base::MessageLoop::current()->PostDelayedTask( | 589 base::MessageLoop::current()->PostDelayedTask( |
574 FROM_HERE, | 590 FROM_HERE, next_refresh_.callback(), next_refresh1 - clock_->Now()); |
575 next_refresh_.callback(), | |
576 next_refresh1 - base::Time::Now()); | |
577 | 591 |
578 base::MessageLoop::current()->PostDelayedTask( | 592 base::MessageLoop::current()->PostDelayedTask( |
579 FROM_HERE, | 593 FROM_HERE, next_refresh_.callback(), next_refresh2 - clock_->Now()); |
580 next_refresh_.callback(), | |
581 next_refresh2 - base::Time::Now()); | |
582 } | 594 } |
583 | 595 |
584 void MDnsListenerImpl::DoRefresh() { | 596 void MDnsListenerImpl::DoRefresh() { |
585 client_->core()->SendQuery(rrtype_, name_); | 597 client_->core()->SendQuery(rrtype_, name_); |
586 } | 598 } |
587 | 599 |
588 MDnsTransactionImpl::MDnsTransactionImpl( | 600 MDnsTransactionImpl::MDnsTransactionImpl( |
589 uint16 rrtype, | 601 uint16 rrtype, |
590 const std::string& name, | 602 const std::string& name, |
591 int flags, | 603 int flags, |
(...skipping 134 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
726 | 738 |
727 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) { | 739 void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) { |
728 TriggerCallback(RESULT_NSEC, NULL); | 740 TriggerCallback(RESULT_NSEC, NULL); |
729 } | 741 } |
730 | 742 |
731 void MDnsTransactionImpl::OnCachePurged() { | 743 void MDnsTransactionImpl::OnCachePurged() { |
732 // TODO(noamsml): Cache purge situations not yet implemented | 744 // TODO(noamsml): Cache purge situations not yet implemented |
733 } | 745 } |
734 | 746 |
735 } // namespace net | 747 } // namespace net |
OLD | NEW |