Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(246)

Unified Diff: net/dns/mdns_listener_impl.cc

Issue 15733008: Multicast DNS implementation (initial) (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@mdns_implementation2
Patch Set: Created 7 years, 7 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: net/dns/mdns_listener_impl.cc
diff --git a/net/dns/mdns_listener_impl.cc b/net/dns/mdns_listener_impl.cc
new file mode 100644
index 0000000000000000000000000000000000000000..94eaa6f74d2f60fe57bef8b231ca783b2dfc1856
--- /dev/null
+++ b/net/dns/mdns_listener_impl.cc
@@ -0,0 +1,555 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_listener_impl.h"
+
+#include "base/bind.h"
+#include "base/time/default_clock.h"
+#include "base/message_loop_proxy.h"
+#include "net/base/dns_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/mdns_query.h"
+#include "net/udp/datagram_socket.h"
+
+namespace net {
+
+static const unsigned kMDnsTransactionTimeoutSeconds = 3;
+
+MDnsListenerFactoryImpl::Core::Core(MDnsListenerFactoryImpl* factory)
+ : socket_ipv4_(new UDPSocket(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL, NetLog::Source())),
+ socket_ipv6_(new UDPSocket(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL, NetLog::Source())),
+ response_ipv4_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
+ response_ipv6_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
+ factory_(factory) {
+}
+
+MDnsListenerFactoryImpl::Core::~Core() {
+ cleanup_callback_.Cancel();
+ socket_ipv4_->Close();
+ socket_ipv6_->Close();
+}
+
+bool MDnsListenerFactoryImpl::Core::Init() {
+ if (!BindSocket(socket_ipv4_.get(), kIPv4AddressSize,
+ kMDNSMulticastGroupIPv4))
+ return false;
+
+ if (!BindSocket(socket_ipv6_.get(), kIPv6AddressSize,
+ kMDNSMulticastGroupIPv6))
+ return false;
+
+ if (!RecieveOnePacket(socket_ipv4_.get(),
szym 2013/05/24 15:32:22 What is the point of doing this?
Noam Samuel 2013/05/24 19:00:49 Oh. I guess this method is somewhat misnamed. Reci
Noam Samuel 2013/05/29 21:25:16 Renamed method and moved it to MDnsConnectionImpl.
+ response_ipv4_.get(),
+ &recv_addr_ipv4_))
+ return false;
+
+ if (!RecieveOnePacket(socket_ipv6_.get(),
+ response_ipv6_.get(),
+ &recv_addr_ipv6_))
+ return false;
+
+ return true;
+}
+
+bool MDnsListenerFactoryImpl::Core::SendQuery(uint16 rrtype, std::string name) {
+ std::string name_dns;
+ if (!DNSDomainFromDot(name, &name_dns))
+ return false;
+
+ MDnsQuery query(name_dns, rrtype);
+
+ int rv = socket_ipv4_->SendTo(
+ query.io_buffer(),
+ query.size(),
+ factory_->GetIPv4SendEndpoint(),
+ base::Bind(&MDnsListenerFactoryImpl::Core::SendDone,
+ AsWeakPtr()));
+ if (rv < OK && rv != ERR_IO_PENDING) return false;
+
+ rv = socket_ipv6_->SendTo(
+ query.io_buffer(),
+ query.size(),
+ factory_->GetIPv6SendEndpoint(),
+ base::Bind(&MDnsListenerFactoryImpl::Core::SendDone,
+ AsWeakPtr()));
+ if (rv < OK && rv != ERR_IO_PENDING) return false;
+
+ return true;
+}
+
+void MDnsListenerFactoryImpl::Core::SendDone(int sent) {
+ // TODO(noamsml): Queueing and retry logic
+}
+
+bool MDnsListenerFactoryImpl::Core::BindSocket(
+ UDPSocket* socket,
+ int addr_size,
+ const char* multicast_group) {
+ IPAddressNumber address_any;
+ address_any.resize(addr_size, 0);
+
+ IPAddressNumber multicast_group_number;
+
+ IPEndPoint bind_endpoint(address_any, dns_protocol::kDefaultPortMulticast);
+
+ bool success = ParseIPLiteralToNumber(multicast_group,
+ &multicast_group_number);
+ DCHECK(success);
+
+ return factory_->BindToAddressAndMulticastGroup(
+ socket, bind_endpoint, multicast_group_number);
+}
+
+bool MDnsListenerFactoryImpl::Core::RecieveOnePacket(
+ UDPSocket* socket,
+ DnsResponse* response,
+ IPEndPoint* recv_addr) {
+ int rval = socket->RecvFrom(
+ response->io_buffer(),
+ response->io_buffer()->size(),
+ recv_addr,
+ base::Bind(&MDnsListenerFactoryImpl::Core::OnDatagramRecieved,
+ AsWeakPtr(), socket, response, recv_addr));
szym 2013/05/25 02:51:01 You don't really need WeakPtr here since if Core i
Noam Samuel 2013/05/29 21:25:16 Done.
+
+ if (rval > 0) {
+ factory_->task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(&MDnsListenerFactoryImpl::Core::OnDatagramRecieved,
szym 2013/05/25 02:51:01 Why not call synchronously?
Noam Samuel 2013/05/29 17:53:03 If an attacker were to spam the mdns port with pac
szym 2013/05/29 17:59:41 The solution to this is to replace recursion with
Noam Samuel 2013/05/29 21:25:16 Done.
+ AsWeakPtr(),
+ socket,
+ response,
+ recv_addr,
+ rval));
+ } else {
+ if (rval != ERR_IO_PENDING) return false;
+ }
+
+ return true;
+}
+
+void MDnsListenerFactoryImpl::Core::OnDatagramRecieved(
+ UDPSocket* socket,
+ DnsResponse* response,
+ IPEndPoint* recv_addr,
+ int bytes_read) {
+ // TODO(noamsml): More sophisticated error handling.
+ DCHECK_GT(bytes_read, 0);
+ HandlePacket(response, bytes_read);
+ bool success = RecieveOnePacket(socket, response, recv_addr);
+
+ DCHECK(success); // TODO(noamsml): exponential backoff.
+}
+
+void MDnsListenerFactoryImpl::Core::HandlePacket(DnsResponse* response,
+ int bytes_read) {
+ unsigned offset;
+
+ if (!response->InitParseWithoutQuery(bytes_read)) {
+ LOG(WARNING) << "Could not understand an mDNS packet.";
+ return; // Message is unreadable.
+ }
+
+ // TODO(noamsml): duplicate query suppression.
+ if (!(response->flags() & dns_protocol::kFlagResponse)) {
+ return; // Message is a query. ignore it.
+ }
+ DnsRecordParser parser = response->Parser();
+ unsigned answer_count = response->answer_count() +
+ response->additional_answer_count();
+
+ for (unsigned i = 0; i < answer_count; i++) {
+ offset = parser.GetOffset();
+ scoped_ptr<const RecordParsed> scoped_record = RecordParsed::CreateFrom(
+ &parser, factory_->clock_->Now());
+
+ if (!scoped_record) {
+ LOG(WARNING) << "Could not understand an mDNS record.";
+
+ if (offset == parser.GetOffset()) {
+ LOG(WARNING) << "Abandoned parsing the rest of the packet.";
+ return; // The parser did not advance, abort reading the packet.
+ } else {
+ continue; // We may be able to extract other records from the packet.
+ }
+ }
+
+ if ((scoped_record->klass() & dns_protocol::kMDnsClassMask) !=
+ dns_protocol::kClassIN) {
+ LOG(WARNING) << "Recieved an mDNS record with non-IN class. Ignoring.";
+ continue; // Ignore all records not in the IN class.
+ }
+
+ // We want to retain a copy of the record pointer for updating listeners
+ // but we are passing ownership to the cache.
+ const RecordParsed* record = scoped_record.get();
+ MDnsCache::UpdateType update = cache_.UpdateDnsRecord(scoped_record.Pass());
+
+ // Cleanup time may have changed.
+ ScheduleCleanup(cache_.next_expiration());
+
+ if (update != MDnsCache::NoChange) {
+ MDnsUpdateType update_external;
+
+ switch (update) {
+ case MDnsCache::RecordAdded:
+ update_external = kMDnsRecordAdded;
+ break;
+ case MDnsCache::RecordChanged:
+ update_external = kMDnsRecordChanged;
+ break;
+ case MDnsCache::NoChange:
+ NOTREACHED();
+ // Dummy assignment to suppress compiler warning.
+ update_external = kMDnsRecordChanged;
+ break;
+ }
+
+ AlertListeners(update_external,
+ ListenerKey(record->type(), record->name()), record);
+ // Alert listeners listening only for rrtype and not for name.
+ AlertListeners(update_external, ListenerKey(record->type(), ""), record);
+ }
+ }
+}
+
+void MDnsListenerFactoryImpl::Core::AlertListeners(
+ MDnsUpdateType update_type,
+ const ListenerKey& key,
+ const RecordParsed* record) {
+ std::pair<ListenersIterator, ListenersIterator> listeners =
+ listeners_.equal_range(key);
+ for (ListenersIterator l = listeners.first; l != listeners.second; ++l) {
+ scoped_ptr<const RecordParsed> record_clone = record->Clone();
+ factory_->task_runner_->PostTask(FROM_HERE, base::Bind(
+ &MDnsListenerImpl::AlertDelegate, l->second->AsWeakPtr(), update_type,
+ base::Owned(record_clone.release())) );
szym 2013/05/24 15:32:22 Why not just call: l->second->AlertDelegate(update
Noam Samuel 2013/05/24 19:00:49 Listener delegates may create or delete any listen
Noam Samuel 2013/05/24 22:34:20 On second thought, deferring shutting down listeni
szym 2013/05/25 02:51:01 You could consider using std::map<ListenerKey, Obs
Noam Samuel 2013/05/29 17:53:03 Taking a scoped reference in the class entrypoints
szym 2013/05/29 17:59:41 Ok, but now you have to deal with a situation wher
Noam Samuel 2013/05/29 18:35:35 Each individual listener is added/removed in an RA
szym 2013/05/29 18:40:43 With deferred deletion, how do you avoid calling b
Noam Samuel 2013/05/29 20:14:49 Oh, sorry. The deferred deletion applies only to t
+ }
+}
+
+void MDnsListenerFactoryImpl::Core::AddListener(
+ MDnsListenerImpl* listener, bool alert_existing_records) {
+ ListenerKey key(listener->GetType(), listener->GetName());
+
+ listeners_.insert(std::pair<ListenerKey, MDnsListenerImpl*>(key, listener));
+
+ if (alert_existing_records) {
+ std::vector<const RecordParsed*> records;
+
+ cache_.FindDnsRecords(listener->GetType(), listener->GetName(),
+ &records, factory_->clock_->Now());
+
+ for (std::vector<const RecordParsed*>::iterator i = records.begin();
+ i != records.end(); i++) {
+ scoped_ptr<const RecordParsed> record_clone = (*i)->Clone();
+ factory_->task_runner_->PostTask(FROM_HERE, base::Bind(
+ &MDnsListenerImpl::AlertDelegate, listener->AsWeakPtr(),
+ kMDnsRecordAdded, base::Owned(record_clone.release())) );
+ }
+ }
+}
+
+void MDnsListenerFactoryImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
+ ListenerKey key(listener->GetType(), listener->GetName());
+
+ bool removed = false;
+
+ std::pair<ListenersIterator, ListenersIterator> listeners =
+ listeners_.equal_range(key);
+ for (ListenersIterator l = listeners.first; l != listeners.second;
+ l++) {
+ if (l->second == listener) {
+ listeners_.erase(l);
+ removed = true;
+ break;
+ }
+ }
+ DCHECK(removed);
+
+ // When we remove a listener, we notify factory that a listen reference has
+ // been removed. This may cause the core to be deleted.
+ factory_->SubtractListenRef();
+}
+
+void MDnsListenerFactoryImpl::Core::ScheduleCleanup(base::Time cleanup) {
+ // Cleanup is already scheduled, no need to do anything.
+ if (cleanup == scheduled_cleanup_) return;
+ scheduled_cleanup_ = cleanup;
+
+ // This line has the effect of cancelling the previously scheduled cleanup.
+ cleanup_callback_.Reset(base::Bind(
+ &MDnsListenerFactoryImpl::Core::DoCleanup, base::Unretained(this)));
+
+ // cleanup == base::Time means no cleanup necessary.
+ if (cleanup != base::Time()) {
+ factory_->task_runner_->PostDelayedTask(
+ FROM_HERE,
+ cleanup_callback_.callback(),
+ cleanup - factory_->clock_->Now());
+ }
+}
+
+void MDnsListenerFactoryImpl::Core::DoCleanup() {
+ cache_.CleanupRecords(factory_->clock_->Now(), base::Bind(
+ &MDnsListenerFactoryImpl::Core::OnRecordRemoved, base::Unretained(this)));
+
+ ScheduleCleanup(cache_.next_expiration());
+}
+
+void MDnsListenerFactoryImpl::Core::OnRecordRemoved(
+ const RecordParsed* record) {
+ AlertListeners(kMDnsRecordRemoved,
+ ListenerKey(record->type(), record->name()), record);
+ // Alert listeners listening only for rrtype and not for name.
+ AlertListeners(kMDnsRecordRemoved, ListenerKey(record->type(), ""),
+ record);
+}
+
+void MDnsListenerFactoryImpl::Core::QueryCache(
+ uint16 rrtype, const std::string& name,
+ std::vector<const RecordParsed*>* records) const {
+ cache_.FindDnsRecords(rrtype, name, records, factory_->clock_->Now());
+}
+
+MDnsListenerFactoryImpl::MDnsListenerFactoryImpl()
+ : listen_refs_(0), clock_owned_(new base::DefaultClock()),
+ task_runner_(base::MessageLoopProxy::current()) {
+ clock_ = clock_owned_.get();
+}
+
+MDnsListenerFactoryImpl::MDnsListenerFactoryImpl(base::Clock* clock,
+ base::TaskRunner* task_runner)
+ : listen_refs_(0), clock_(clock), task_runner_(task_runner) {
+}
+
+MDnsListenerFactoryImpl::~MDnsListenerFactoryImpl() {
+}
+
+bool MDnsListenerFactoryImpl::AddListenRef() {
+ if (listen_refs_ == 0) {
+ core_.reset(new Core(this));
+ if (!core_->Init()) {
+ core_.reset();
+ return false;
+ }
+ }
+ listen_refs_++;
+ return true;
+}
+
+void MDnsListenerFactoryImpl::SubtractListenRef() {
+ listen_refs_--;
+ if (listen_refs_ == 0) core_.reset(NULL);
+}
+
+bool MDnsListenerFactoryImpl::BindToAddressAndMulticastGroup(
+ UDPSocket* socket,
+ const IPEndPoint& address_bind,
+ const IPAddressNumber& group_multicast) {
+ socket->AllowAddressReuse();
+ int status = socket->Bind(address_bind);
+
+ if (status < 0)
+ return false;
+
+ socket->SetMulticastLoopbackMode(false);
+
+ status = socket->JoinGroup(group_multicast);
+
+ if (status < 0)
+ return false;
+
+ return true;
+}
+
+bool MDnsListenerFactoryImpl::IsListeningForTests() {
+ return core_.get() != NULL;
+}
+
+IPEndPoint MDnsListenerFactoryImpl::GetIPv4SendEndpoint() {
+ IPAddressNumber multicast_group_number;
+ bool success = ParseIPLiteralToNumber(kMDNSMulticastGroupIPv4,
+ &multicast_group_number);
+ DCHECK(success);
+ return IPEndPoint(multicast_group_number,
+ dns_protocol::kDefaultPortMulticast);
+}
+
+IPEndPoint MDnsListenerFactoryImpl::GetIPv6SendEndpoint() {
+ IPAddressNumber multicast_group_number;
+ bool success = ParseIPLiteralToNumber(kMDNSMulticastGroupIPv6,
+ &multicast_group_number);
+ DCHECK(success);
+ return IPEndPoint(multicast_group_number,
+ dns_protocol::kDefaultPortMulticast);
+}
+
+scoped_ptr<MDnsListener> MDnsListenerFactoryImpl::CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ bool active,
+ bool alert_existing_records,
+ MDnsListenerFactory::Delegate* delegate) {
+ if (!AddListenRef()) return scoped_ptr<net::MDnsListener>();
+
+ return scoped_ptr<net::MDnsListener>(
+ new MDnsListenerImpl(rrtype, name, active, alert_existing_records,
+ delegate, core_.get()));
+}
+
+scoped_ptr<MDnsTransaction> MDnsListenerFactoryImpl::CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ const MDnsListenerFactory::QueryCallback& callback) {
+ scoped_ptr<MDnsTransactionImpl> transaction(
+ new MDnsTransactionImpl(rrtype, name, callback, task_runner_));
+
+ if (transaction->Init(this, core_.get())) {
+ return scoped_ptr<MDnsTransaction>(transaction.Pass());
+ } else {
+ return scoped_ptr<MDnsTransaction>();
+ }
+}
+
+MDnsListenerImpl::MDnsListenerImpl(
+ uint16 rrtype,
+ const std::string& name,
+ bool active,
+ bool alert_existing_records,
+ MDnsListenerFactory::Delegate* delegate,
+ MDnsListenerFactoryImpl::Core* core)
+ : rrtype_(rrtype), name_(name), active_(active),
+ parent_(core), delegate_(delegate) {
+ parent_->AddListener(this, alert_existing_records);
+
+ if (active) SendQuery(false); // TODO(noamsml): Retry logic.
+}
+
+MDnsListenerImpl::~MDnsListenerImpl() {
+ parent_->RemoveListener(this);
+}
+
+const std::string& MDnsListenerImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsListenerImpl::GetType() const {
+ return rrtype_;
+}
+
+bool MDnsListenerImpl::IsActive() const {
+ return active_;
+}
+
+bool MDnsListenerImpl::SendQuery(bool force_refresh_cache) {
+ // TODO(noamsml): Logic related to force_refresh_cache
+ if (name_.size() == 0) return false;
+ return parent_->SendQuery(rrtype_, name_);
+}
+
+bool MDnsListenerImpl::QueryCache(
+ std::vector<const RecordParsed*>* records) const {
+ if (name_.size() == 0) return false;
+ parent_->QueryCache(rrtype_, name_, records);
+ return true;
+}
+
+void MDnsListenerImpl::AlertDelegate(MDnsUpdateType update_type,
+ const RecordParsed* record) {
+ delegate_->OnRecordUpdate(update_type, record);
+}
+
+MDnsTransactionImpl::MDnsTransactionImpl(
+ uint16 rrtype,
+ const std::string& name,
+ const MDnsListenerFactory::QueryCallback& callback,
+ base::TaskRunner* task_runner)
+ : rrtype_(rrtype), name_(name), callback_(callback), triggered_(false),
+ task_runner_(task_runner) {
+}
+
+MDnsTransactionImpl::~MDnsTransactionImpl() {
+}
+
+bool MDnsTransactionImpl::Init(
+ MDnsListenerFactoryImpl* factory,
+ MDnsListenerFactoryImpl::Core* core) {
+ DCHECK(factory);
+ std::vector<const RecordParsed*> records;
+ if (core) {
+ core->QueryCache(rrtype_, name_, &records);
+ if (!records.empty()) {
+ scoped_ptr<const RecordParsed> record_clone = records.front()->Clone();
+ task_runner_->PostTask(
+ FROM_HERE,
+ base::Bind(&MDnsTransactionImpl::CacheRecordFound,
+ AsWeakPtr(), base::Owned(
+ record_clone.release())) );
+
+ return true;
+ }
+ }
+
+ listener_ = factory->CreateListener(rrtype_, name_, true /*active*/,
+ false /*alert existing*/, this);
+
+ timeout_.Reset(base::Bind(&MDnsTransactionImpl::OnTimedOut, AsWeakPtr()));
+ task_runner_->PostDelayedTask(
+ FROM_HERE,
+ timeout_.callback(),
+ base::TimeDelta::FromSeconds(kMDnsTransactionTimeoutSeconds));
+
+ return listener_.get() != NULL;
+}
+
+const std::string& MDnsTransactionImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsTransactionImpl::GetType() const {
+ return rrtype_;
+}
+
+void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
+ OnRecordUpdate(kMDnsRecordAdded, record);
+}
+
+void MDnsTransactionImpl::TriggerCallback(MDnsTransactionResult result,
+ const RecordParsed* record) {
+ if (triggered_) return;
+ triggered_ = true;
+
+ // Ensure callback is run after touching all class state, so that
+ // the callback can delete the transaction.
+ MDnsListenerFactory::QueryCallback callback = callback_;
+
+ callback_.Reset();
+ listener_.reset();
+ timeout_.Cancel();
+
+ callback.Run(result, record);
+}
+
+void MDnsTransactionImpl::OnRecordUpdate(MDnsUpdateType update,
+ const RecordParsed* record) {
+ if (update == kMDnsRecordAdded || update == kMDnsRecordChanged) {
+ TriggerCallback(kMDnsTransactionSuccess, record);
+ }
+}
+
+void MDnsTransactionImpl::OnTimedOut() {
+ TriggerCallback(kMDnsTransactionTimeout, NULL);
+}
+
+void MDnsTransactionImpl::OnNSecRecord(const std::string& name, unsigned type) {
+ // TODO(noamsml): NSEC records not yet implemented
+}
+
+} // namespace net

Powered by Google App Engine
This is Rietveld 408576698