Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(100)

Unified Diff: net/dns/mdns_client_impl.cc

Issue 15733008: Multicast DNS implementation (initial) (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@mdns_implementation2
Patch Set: Renamed files from "listener" to "client" Created 7 years, 7 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: net/dns/mdns_client_impl.cc
diff --git a/net/dns/mdns_client_impl.cc b/net/dns/mdns_client_impl.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e63b17fa9629601caf365307e30f1e2eb5f54682
--- /dev/null
+++ b/net/dns/mdns_client_impl.cc
@@ -0,0 +1,600 @@
+// Copyright (c) 2013 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "net/dns/mdns_client_impl.h"
+
+#include "base/bind.h"
+#include "base/message_loop_proxy.h"
+#include "base/stl_util.h"
+#include "base/time/default_clock.h"
+#include "net/base/dns_util.h"
+#include "net/base/net_errors.h"
+#include "net/base/net_log.h"
+#include "net/base/rand_callback.h"
+#include "net/dns/dns_protocol.h"
+#include "net/dns/mdns_query.h"
+#include "net/udp/datagram_socket.h"
+
+namespace net {
+
+static const char kMDNSMulticastGroupIPv4[] = "224.0.0.251";
+static const char kMDNSMulticastGroupIPv6[] = "FF02::FB";
+
+static const unsigned kMDnsTransactionTimeoutSeconds = 3;
+
+MDnsClientImpl::Core::Core(MDnsClientImpl* client,
+ MDnsConnectionFactory* connection_factory,
+ base::TaskRunner* task_runner,
+ base::Clock* clock)
+ : client_(client), task_runner_(task_runner), clock_(clock),
szym 2013/06/02 19:01:23 The only reason I see for this dependency on TaskR
Noam Samuel 2013/06/04 00:08:03 Done.
+ connection_(connection_factory->CreateConnection(this, task_runner)) {
+}
+
+MDnsClientImpl::Core::~Core() {
+ cleanup_callback_.Cancel();
+ STLDeleteValues(&listeners_);
+}
+
+bool MDnsClientImpl::Core::Init() {
+ return connection_->Init();
+}
+
+bool MDnsClientImpl::Core::SendQuery(uint16 rrtype, std::string name) {
+ std::string name_dns;
+ if (!DNSDomainFromDot(name, &name_dns))
+ return false;
+
+ MDnsQuery query(name_dns, rrtype);
+
+ connection_->Send(query.io_buffer(), query.size());
+
+ return true;
+}
+
+void MDnsClientImpl::Core::HandlePacket(DnsResponse* response,
+ int bytes_read) {
+ unsigned offset;
+
+ if (!response->InitParseWithoutQuery(bytes_read)) {
+ LOG(WARNING) << "Could not understand an mDNS packet.";
+ return; // Message is unreadable.
+ }
+
+ // TODO(noamsml): duplicate query suppression.
+ if (!(response->flags() & dns_protocol::kFlagResponse)) {
+ return; // Message is a query. ignore it.
+ }
+ DnsRecordParser parser = response->Parser();
+ unsigned answer_count = response->answer_count() +
+ response->additional_answer_count();
+
+ for (unsigned i = 0; i < answer_count; i++) {
+ offset = parser.GetOffset();
+ scoped_ptr<const RecordParsed> scoped_record = RecordParsed::CreateFrom(
+ &parser, clock_->Now());
+
+ if (!scoped_record) {
+ LOG(WARNING) << "Could not understand an mDNS record.";
+
+ if (offset == parser.GetOffset()) {
+ LOG(WARNING) << "Abandoned parsing the rest of the packet.";
+ return; // The parser did not advance, abort reading the packet.
+ } else {
+ continue; // We may be able to extract other records from the packet.
+ }
+ }
+
+ if ((scoped_record->klass() & dns_protocol::kMDnsClassMask) !=
+ dns_protocol::kClassIN) {
+ LOG(WARNING) << "Received an mDNS record with non-IN class. Ignoring.";
+ continue; // Ignore all records not in the IN class.
+ }
+
+ // We want to retain a copy of the record pointer for updating listeners
+ // but we are passing ownership to the cache.
+ const RecordParsed* record = scoped_record.get();
+ MDnsCache::UpdateType update = cache_.UpdateDnsRecord(scoped_record.Pass());
+
+ // Cleanup time may have changed.
+ ScheduleCleanup(cache_.next_expiration());
+
+ if (update != MDnsCache::NoChange) {
+ MDnsUpdateType update_external;
+
+ switch (update) {
+ case MDnsCache::RecordAdded:
+ update_external = kMDnsRecordAdded;
+ break;
+ case MDnsCache::RecordChanged:
+ update_external = kMDnsRecordChanged;
+ break;
+ case MDnsCache::NoChange:
+ NOTREACHED();
+ // Dummy assignment to suppress compiler warning.
+ update_external = kMDnsRecordChanged;
+ break;
+ }
+
+ AlertListeners(update_external,
+ ListenerKey(record->type(), record->name()), record);
+ // Alert listeners listening only for rrtype and not for name.
+ AlertListeners(update_external, ListenerKey(record->type(), ""), record);
+ }
+ }
+}
+
+void MDnsClientImpl::Core::AlertListeners(
+ MDnsUpdateType update_type,
+ const ListenerKey& key,
+ const RecordParsed* record) {
+ ListenerMap::iterator listener_map_iterator = listeners_.find(key);
+ if (listener_map_iterator == listeners_.end()) return;
+
+ FOR_EACH_OBSERVER(MDnsListenerImpl, *listener_map_iterator->second,
+ AlertDelegate(update_type, record));
+}
+
+void MDnsClientImpl::Core::AddListener(
+ MDnsListenerImpl* listener, bool alert_existing_records) {
+ ListenerKey key(listener->GetType(), listener->GetName());
+ std::pair<ListenerMap::iterator, bool> observer_insert_result =
+ listeners_.insert(
+ make_pair(key, static_cast<ObserverList<MDnsListenerImpl>*>(NULL)));
+
+ // If an equivalent key does not exist, actually create the observer list.
+ if (observer_insert_result.second) {
+ observer_insert_result.first->second = new ObserverList<MDnsListenerImpl>();
+ }
+
+ ObserverList<MDnsListenerImpl>* observer_list =
+ observer_insert_result.first->second;
+
+ observer_list->AddObserver(listener);
+
+ if (alert_existing_records) {
+ std::vector<const RecordParsed*> records;
+
+ cache_.FindDnsRecords(listener->GetType(), listener->GetName(),
+ &records, clock_->Now());
+
+ for (std::vector<const RecordParsed*>::iterator i = records.begin();
+ i != records.end(); i++) {
+ listener->AlertDelegate(kMDnsRecordAdded, *i);
+ }
+ }
+}
+
+void MDnsClientImpl::Core::RemoveListener(MDnsListenerImpl* listener) {
+ ListenerKey key(listener->GetType(), listener->GetName());
+ ListenerMap::iterator observer_list_iterator = listeners_.find(key);
+
+ DCHECK(observer_list_iterator != listeners_.end());
+ DCHECK(observer_list_iterator->second->HasObserver(listener));
+
+ observer_list_iterator->second->RemoveObserver(listener);
+
+ // Remove the observer list from the map if it is empty
+ if (observer_list_iterator->second->size() == 0) {
+ delete observer_list_iterator->second;
+ listeners_.erase(observer_list_iterator);
+ }
+
+ // When we remove a listener, we notify client that a listen reference has
+ // been removed. This may cause the core to be deleted.
+ client_->SubtractListenRef();
+}
+
+void MDnsClientImpl::Core::ScheduleCleanup(base::Time cleanup) {
+ // Cleanup is already scheduled, no need to do anything.
+ if (cleanup == scheduled_cleanup_) return;
+ scheduled_cleanup_ = cleanup;
+
+ // This line has the effect of cancelling the previously scheduled cleanup.
+ cleanup_callback_.Reset(base::Bind(
+ &MDnsClientImpl::Core::DoCleanup, base::Unretained(this)));
+
+ // cleanup == base::Time means no cleanup necessary.
+ if (cleanup != base::Time()) {
+ task_runner_->PostDelayedTask(
+ FROM_HERE,
+ cleanup_callback_.callback(),
+ cleanup - clock_->Now());
+ }
+}
+
+void MDnsClientImpl::Core::DoCleanup() {
+ cache_.CleanupRecords(clock_->Now(), base::Bind(
+ &MDnsClientImpl::Core::OnRecordRemoved, base::Unretained(this)));
+
+ ScheduleCleanup(cache_.next_expiration());
+}
+
+void MDnsClientImpl::Core::OnRecordRemoved(
+ const RecordParsed* record) {
+ AlertListeners(kMDnsRecordRemoved,
+ ListenerKey(record->type(), record->name()), record);
+ // Alert listeners listening only for rrtype and not for name.
+ AlertListeners(kMDnsRecordRemoved, ListenerKey(record->type(), ""),
+ record);
+}
+
+void MDnsClientImpl::Core::QueryCache(
+ uint16 rrtype, const std::string& name,
+ std::vector<const RecordParsed*>* records) const {
+ cache_.FindDnsRecords(rrtype, name, records, clock_->Now());
+}
+
+MDnsClientImpl::MDnsClientImpl()
+ : listen_refs_(0), clock_owned_(new base::DefaultClock()),
+ connection_factory_owned_(new MDnsConnectionImplFactory()),
+ task_runner_(base::MessageLoopProxy::current()) {
+ clock_ = clock_owned_.get();
+ connection_factory_ = connection_factory_owned_.get();
+}
+
+MDnsClientImpl::MDnsClientImpl(base::Clock* clock,
+ base::TaskRunner* task_runner,
+ MDnsConnectionFactory* connection_factory)
+ : listen_refs_(0), clock_(clock), connection_factory_(connection_factory),
+ task_runner_(task_runner) {
+}
+
+MDnsClientImpl::~MDnsClientImpl() {
+}
+
+bool MDnsClientImpl::AddListenRef() {
+ if (!core_.get()) {
+ core_.reset(new Core(this, connection_factory_, task_runner_, clock_));
+ if (!core_->Init()) {
+ core_.reset();
+ return false;
+ }
+ }
+ listen_refs_++;
+ return true;
+}
+
+void MDnsClientImpl::SubtractListenRef() {
+ listen_refs_--;
+ if (listen_refs_ == 0) {
+ task_runner_->PostTask(FROM_HERE, base::Bind(
+ &MDnsClientImpl::Shutdown, base::Unretained(this)));
+ }
+}
+
+void MDnsClientImpl::Shutdown() {
+ // We need to check that new listeners haven't been created.
+ if (listen_refs_ == 0) {
+ core_.reset();
+ }
+}
+
+bool MDnsClientImpl::IsListeningForTests() {
+ return core_.get() != NULL;
+}
+
+scoped_ptr<MDnsListener> MDnsClientImpl::CreateListener(
+ uint16 rrtype,
+ const std::string& name,
+ bool active,
+ bool alert_existing_records,
+ MDnsListener::Delegate* delegate) {
+ if (!AddListenRef()) return scoped_ptr<net::MDnsListener>();
+
+ return scoped_ptr<net::MDnsListener>(
+ new MDnsListenerImpl(rrtype, name, active, alert_existing_records,
+ delegate, core_.get()));
+}
+
+scoped_ptr<MDnsTransaction> MDnsClientImpl::CreateTransaction(
+ uint16 rrtype,
+ const std::string& name,
+ const MDnsTransaction::ResultCallback& callback) {
+ scoped_ptr<MDnsTransactionImpl> transaction(
+ new MDnsTransactionImpl(rrtype, name, callback, task_runner_));
+
+ if (transaction->Init(this, core_.get())) {
+ return scoped_ptr<MDnsTransaction>(transaction.Pass());
+ } else {
+ return scoped_ptr<MDnsTransaction>();
+ }
+}
+
+MDnsListenerImpl::MDnsListenerImpl(
+ uint16 rrtype,
+ const std::string& name,
+ bool active,
+ bool alert_existing_records,
+ MDnsListener::Delegate* delegate,
+ MDnsClientImpl::Core* core)
+ : rrtype_(rrtype), name_(name), active_(active),
+ parent_(core), delegate_(delegate) {
+ parent_->AddListener(this, alert_existing_records);
+
+ if (active) SendQuery(false); // TODO(noamsml): Retry logic.
+}
+
+MDnsListenerImpl::~MDnsListenerImpl() {
+ parent_->RemoveListener(this);
+}
+
+const std::string& MDnsListenerImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsListenerImpl::GetType() const {
+ return rrtype_;
+}
+
+bool MDnsListenerImpl::IsActive() const {
+ return active_;
+}
+
+bool MDnsListenerImpl::SendQuery(bool force_refresh_cache) {
+ // TODO(noamsml): Logic related to force_refresh_cache
+ if (name_.size() == 0) return false;
+ return parent_->SendQuery(rrtype_, name_);
+}
+
+bool MDnsListenerImpl::QueryCache(
+ std::vector<const RecordParsed*>* records) const {
+ if (name_.size() == 0) return false;
+ parent_->QueryCache(rrtype_, name_, records);
+ return true;
+}
+
+void MDnsListenerImpl::AlertDelegate(MDnsUpdateType update_type,
+ const RecordParsed* record) {
+ delegate_->OnRecordUpdate(update_type, record);
+}
+
+MDnsTransactionImpl::MDnsTransactionImpl(
+ uint16 rrtype,
+ const std::string& name,
+ const MDnsTransaction::ResultCallback& callback,
+ base::TaskRunner* task_runner)
+ : rrtype_(rrtype), name_(name), callback_(callback), triggered_(false),
+ task_runner_(task_runner) {
+}
+
+MDnsTransactionImpl::~MDnsTransactionImpl() {
+}
+
+bool MDnsTransactionImpl::Init(
+ MDnsClientImpl* client,
+ MDnsClientImpl::Core* core) {
+ DCHECK(client);
+ std::vector<const RecordParsed*> records;
+ if (core) {
+ core->QueryCache(rrtype_, name_, &records);
+ if (!records.empty()) {
+ scoped_ptr<const RecordParsed> record_clone = records.front()->Clone();
szym 2013/06/02 19:01:23 Why do you need to Clone?
Noam Samuel 2013/06/04 00:08:03 Removed.
+ task_runner_->PostTask(
szym 2013/06/02 19:01:23 Why not call CacheRecordFound directly?
Noam Samuel 2013/06/04 00:08:03 Synchronized.
+ FROM_HERE,
+ base::Bind(&MDnsTransactionImpl::CacheRecordFound,
+ AsWeakPtr(), base::Owned(
+ record_clone.release())) );
+
+ return true;
+ }
+ }
+
+ listener_ = client->CreateListener(rrtype_, name_, true /*active*/,
+ false /*alert existing*/, this);
+
+ timeout_.Reset(base::Bind(&MDnsTransactionImpl::OnTimedOut, AsWeakPtr()));
+ task_runner_->PostDelayedTask(
+ FROM_HERE,
+ timeout_.callback(),
+ base::TimeDelta::FromSeconds(kMDnsTransactionTimeoutSeconds));
+
+ return listener_.get() != NULL;
+}
+
+const std::string& MDnsTransactionImpl::GetName() const {
+ return name_;
+}
+
+uint16 MDnsTransactionImpl::GetType() const {
+ return rrtype_;
+}
+
+void MDnsTransactionImpl::CacheRecordFound(const RecordParsed* record) {
+ OnRecordUpdate(kMDnsRecordAdded, record);
+}
+
+void MDnsTransactionImpl::TriggerCallback(MDnsTransactionResult result,
+ const RecordParsed* record) {
+ if (triggered_) return;
+ triggered_ = true;
szym 2013/06/02 19:01:23 Instead of adding |triggered_|, I suggest: if (cal
Noam Samuel 2013/06/04 00:08:03 Done.
+
+ // Ensure callback is run after touching all class state, so that
+ // the callback can delete the transaction.
+ MDnsTransaction::ResultCallback callback = callback_;
+
+ callback_.Reset();
+ listener_.reset();
+ timeout_.Cancel();
+
+ callback.Run(result, record);
+}
+
+void MDnsTransactionImpl::OnRecordUpdate(MDnsUpdateType update,
+ const RecordParsed* record) {
+ if (update == kMDnsRecordAdded || update == kMDnsRecordChanged) {
+ TriggerCallback(kMDnsTransactionSuccess, record);
+ }
+}
+
+void MDnsTransactionImpl::OnTimedOut() {
+ TriggerCallback(kMDnsTransactionTimeout, NULL);
+}
+
+void MDnsTransactionImpl::OnNsecRecord(const std::string& name, unsigned type) {
+ // TODO(noamsml): NSEC records not yet implemented
+}
+
+MDnsConnectionImpl::MDnsConnectionImpl(MDnsConnection::Delegate* delegate,
+ base::TaskRunner* task_runner)
+ : socket_ipv4_(new UDPSocket(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL, NetLog::Source())),
+ socket_ipv6_(new UDPSocket(DatagramSocket::DEFAULT_BIND,
+ RandIntCallback(),
+ NULL, NetLog::Source())),
+ response_ipv4_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
+ response_ipv6_(new DnsResponse(dns_protocol::kMaxMulticastSize)),
+ delegate_(delegate),
+ task_runner_(task_runner) {
+}
+
+MDnsConnectionImpl::~MDnsConnectionImpl() {
+ socket_ipv4_->Close();
+ socket_ipv6_->Close();
+}
+
+bool MDnsConnectionImpl::Init() {
+ if (!BindSocket(socket_ipv4_.get(), kIPv4AddressSize,
+ kMDNSMulticastGroupIPv4))
+ return false;
+
+ if (!BindSocket(socket_ipv6_.get(), kIPv6AddressSize,
+ kMDNSMulticastGroupIPv6))
+ return false;
+
+ if (!ReceiveNextPacket(socket_ipv4_.get(),
+ response_ipv4_.get(),
+ &recv_addr_ipv4_))
+ return false;
+
+ if (!ReceiveNextPacket(socket_ipv6_.get(),
+ response_ipv6_.get(),
+ &recv_addr_ipv6_))
+ return false;
+
+ return true;
+}
+
+bool MDnsConnectionImpl::Send(IOBuffer* buffer, unsigned size) {
+ int rv = socket_ipv4_->SendTo(
+ buffer,
+ size,
+ GetIPv4SendEndpoint(),
+ base::Bind(&MDnsConnectionImpl::SendDone,
+ base::Unretained(this) ));
+ if (rv < OK && rv != ERR_IO_PENDING) return false;
+
+ rv = socket_ipv6_->SendTo(
+ buffer,
+ size,
+ GetIPv6SendEndpoint(),
+ base::Bind(&MDnsConnectionImpl::SendDone,
+ base::Unretained(this) ));
+ if (rv < OK && rv != ERR_IO_PENDING) return false;
+
+ return true;
+}
+
+void MDnsConnectionImpl::SendDone(int sent) {
+ // TODO(noamsml): Queueing and retry logic
+}
+
+bool MDnsConnectionImpl::BindSocket(
+ UDPSocket* socket,
+ int addr_size,
+ const char* multicast_group) {
+ IPAddressNumber address_any;
+ address_any.resize(addr_size, 0);
+
+ IPAddressNumber multicast_group_number;
+
+ IPEndPoint bind_endpoint(address_any, dns_protocol::kDefaultPortMulticast);
+
+ bool success = ParseIPLiteralToNumber(multicast_group,
+ &multicast_group_number);
+ DCHECK(success);
+
+ socket->AllowAddressReuse();
+ int status = socket->Bind(bind_endpoint);
+
+ if (status < 0)
+ return false;
+
+ socket->SetMulticastLoopbackMode(false);
+
+ status = socket->JoinGroup(multicast_group_number);
+
+ if (status < 0)
+ return false;
+
+ return true;
+}
+
+IPEndPoint MDnsConnectionImpl::GetIPv4SendEndpoint() {
+ IPAddressNumber multicast_group_number;
+ bool success = ParseIPLiteralToNumber(kMDNSMulticastGroupIPv4,
+ &multicast_group_number);
+ DCHECK(success);
+ return IPEndPoint(multicast_group_number,
+ dns_protocol::kDefaultPortMulticast);
+}
+
+IPEndPoint MDnsConnectionImpl::GetIPv6SendEndpoint() {
+ IPAddressNumber multicast_group_number;
+ bool success = ParseIPLiteralToNumber(kMDNSMulticastGroupIPv6,
+ &multicast_group_number);
+ DCHECK(success);
+ return IPEndPoint(multicast_group_number,
+ dns_protocol::kDefaultPortMulticast);
+}
+
+void MDnsConnectionImpl::OnDatagramReceived(
+ UDPSocket* socket,
+ DnsResponse* response,
+ IPEndPoint* recv_addr,
+ int bytes_read) {
+ // TODO(noamsml): More sophisticated error handling.
+ DCHECK_GT(bytes_read, 0);
+ delegate_->HandlePacket(response, bytes_read);
+ bool success = ReceiveNextPacket(socket, response, recv_addr);
+
+ DCHECK(success); // TODO(noamsml): exponential backoff.
+}
+
+bool MDnsConnectionImpl::ReceiveNextPacket(
+ UDPSocket* socket,
+ DnsResponse* response,
+ IPEndPoint* recv_addr) {
+ int rval;
+ do {
+ rval = socket->RecvFrom(
+ response->io_buffer(),
+ response->io_buffer()->size(),
+ recv_addr,
+ base::Bind(&MDnsConnectionImpl::OnDatagramReceived,
+ base::Unretained(this), socket, response, recv_addr));
+
+ if (rval > 0) {
+ delegate_->HandlePacket(response, rval);
+ }
+ } while (rval > 0);
+
+ if (rval != ERR_IO_PENDING) return false;
+ return true;
+}
+
+MDnsConnectionImplFactory::MDnsConnectionImplFactory() {
+}
+
+MDnsConnectionImplFactory::~MDnsConnectionImplFactory() {
+}
+
+scoped_ptr<MDnsConnection> MDnsConnectionImplFactory::CreateConnection(
+ MDnsConnection::Delegate* delegate,
+ base::TaskRunner* task_runner) {
+ return scoped_ptr<MDnsConnection>(new MDnsConnectionImpl(delegate,
+ task_runner));
+}
+
+} // namespace net

Powered by Google App Engine
This is Rietveld 408576698