| OLD | NEW |
| 1 // Copyright 2013 The Chromium Authors. All rights reserved. | 1 // Copyright 2013 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "chrome/utility/local_discovery/service_discovery_message_handler.h" | 5 #include "chrome/utility/local_discovery/service_discovery_message_handler.h" |
| 6 | 6 |
| 7 #include "base/command_line.h" |
| 7 #include "chrome/common/local_discovery/local_discovery_messages.h" | 8 #include "chrome/common/local_discovery/local_discovery_messages.h" |
| 8 #include "chrome/utility/local_discovery/service_discovery_client_impl.h" | 9 #include "chrome/utility/local_discovery/service_discovery_client_impl.h" |
| 10 #include "content/public/common/content_switches.h" |
| 9 #include "content/public/utility/utility_thread.h" | 11 #include "content/public/utility/utility_thread.h" |
| 10 | 12 |
| 13 #if defined(OS_WIN) |
| 14 |
| 15 #include "base/lazy_instance.h" |
| 16 #include "net/base/winsock_init.h" |
| 17 #include "net/base/winsock_util.h" |
| 18 |
| 19 #endif // OS_WIN |
| 20 |
| 21 namespace { |
| 22 |
| 23 bool NeedsSockets() { |
| 24 return !CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoSandbox) && |
| 25 CommandLine::ForCurrentProcess()->HasSwitch( |
| 26 switches::kUtilityProcessEnableMDns); |
| 27 } |
| 28 |
| 29 #if defined(OS_WIN) |
| 30 |
| 31 class SocketFactory : public net::PlatformSocketFactory { |
| 32 public: |
| 33 SocketFactory() |
| 34 : socket_v4_(NULL), |
| 35 socket_v6_(NULL) { |
| 36 net::EnsureWinsockInit(); |
| 37 socket_v4_ = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, |
| 38 WSA_FLAG_OVERLAPPED); |
| 39 socket_v6_ = WSASocket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP, NULL, 0, |
| 40 WSA_FLAG_OVERLAPPED); |
| 41 } |
| 42 |
| 43 void Reset() { |
| 44 if (socket_v4_ != INVALID_SOCKET) { |
| 45 closesocket(socket_v4_); |
| 46 socket_v4_ = INVALID_SOCKET; |
| 47 } |
| 48 if (socket_v6_ != INVALID_SOCKET) { |
| 49 closesocket(socket_v6_); |
| 50 socket_v6_ = INVALID_SOCKET; |
| 51 } |
| 52 } |
| 53 |
| 54 virtual ~SocketFactory() { |
| 55 Reset(); |
| 56 } |
| 57 |
| 58 virtual SOCKET CreateSocket(int family, int type, int protocol) OVERRIDE { |
| 59 SOCKET result = INVALID_SOCKET; |
| 60 if (type != SOCK_DGRAM && protocol != IPPROTO_UDP) { |
| 61 NOTREACHED(); |
| 62 } else if (family == AF_INET) { |
| 63 std::swap(result, socket_v4_); |
| 64 } else if (family == AF_INET6) { |
| 65 std::swap(result, socket_v6_); |
| 66 } |
| 67 return result; |
| 68 } |
| 69 |
| 70 SOCKET socket_v4_; |
| 71 SOCKET socket_v6_; |
| 72 |
| 73 DISALLOW_COPY_AND_ASSIGN(SocketFactory); |
| 74 }; |
| 75 |
| 76 base::LazyInstance<SocketFactory> |
| 77 g_local_discovery_socket_factory = LAZY_INSTANCE_INITIALIZER; |
| 78 |
| 79 class ScopedSocketFactorySetter { |
| 80 public: |
| 81 ScopedSocketFactorySetter() { |
| 82 if (NeedsSockets()) { |
| 83 net::PlatformSocketFactory::SetInstance( |
| 84 &g_local_discovery_socket_factory.Get()); |
| 85 } |
| 86 } |
| 87 |
| 88 ~ScopedSocketFactorySetter() { |
| 89 if (NeedsSockets()) { |
| 90 net::PlatformSocketFactory::SetInstance(NULL); |
| 91 g_local_discovery_socket_factory.Get().Reset(); |
| 92 } |
| 93 } |
| 94 |
| 95 static void Initialize() { |
| 96 if (NeedsSockets()) { |
| 97 g_local_discovery_socket_factory.Get(); |
| 98 } |
| 99 } |
| 100 |
| 101 private: |
| 102 DISALLOW_COPY_AND_ASSIGN(ScopedSocketFactorySetter); |
| 103 }; |
| 104 |
| 105 #else // OS_WIN |
| 106 |
| 107 class ScopedSocketFactorySetter { |
| 108 public: |
| 109 ScopedSocketFactorySetter() {} |
| 110 |
| 111 static void Initialize() { |
| 112 // TODO(vitalybuka) : implement socket access from sandbox for other |
| 113 // platforms. |
| 114 DCHECK(!NeedsSockets()); |
| 115 } |
| 116 }; |
| 117 |
| 118 #endif // OS_WIN |
| 119 |
| 120 } // namespace |
| 121 |
| 11 namespace local_discovery { | 122 namespace local_discovery { |
| 12 | 123 |
| 13 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() { | 124 ServiceDiscoveryMessageHandler::ServiceDiscoveryMessageHandler() { |
| 14 } | 125 } |
| 15 | 126 |
| 16 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() { | 127 ServiceDiscoveryMessageHandler::~ServiceDiscoveryMessageHandler() { |
| 17 } | 128 } |
| 18 | 129 |
| 19 void ServiceDiscoveryMessageHandler::Initialize() { | 130 void ServiceDiscoveryMessageHandler::PreSandboxStartup() { |
| 20 if (!service_discovery_client_) { | 131 ScopedSocketFactorySetter::Initialize(); |
| 21 mdns_client_ = net::MDnsClient::CreateDefault(); | 132 } |
| 22 mdns_client_->StartListening(); | 133 |
| 23 service_discovery_client_.reset( | 134 bool ServiceDiscoveryMessageHandler::Initialize() { |
| 24 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get())); | 135 if (service_discovery_client_) |
| 136 return true; |
| 137 |
| 138 if (mdns_client_) // We tried but failed before. |
| 139 return false; |
| 140 |
| 141 mdns_client_ = net::MDnsClient::CreateDefault(); |
| 142 { |
| 143 // Temporarily redirect network code to use pre-created sockets. |
| 144 ScopedSocketFactorySetter socket_factory_setter; |
| 145 if (!mdns_client_->StartListening()) |
| 146 return false; |
| 25 } | 147 } |
| 148 |
| 149 service_discovery_client_.reset( |
| 150 new local_discovery::ServiceDiscoveryClientImpl(mdns_client_.get())); |
| 151 return true; |
| 26 } | 152 } |
| 27 | 153 |
| 28 bool ServiceDiscoveryMessageHandler::OnMessageReceived( | 154 bool ServiceDiscoveryMessageHandler::OnMessageReceived( |
| 29 const IPC::Message& message) { | 155 const IPC::Message& message) { |
| 30 bool handled = true; | 156 bool handled = true; |
| 31 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message) | 157 IPC_BEGIN_MESSAGE_MAP(ServiceDiscoveryMessageHandler, message) |
| 32 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher) | 158 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_StartWatcher, OnStartWatcher) |
| 33 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices) | 159 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DiscoverServices, OnDiscoverServices) |
| 34 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher) | 160 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyWatcher, OnDestroyWatcher) |
| 35 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService) | 161 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_ResolveService, OnResolveService) |
| 36 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver) | 162 IPC_MESSAGE_HANDLER(LocalDiscoveryMsg_DestroyResolver, OnDestroyResolver) |
| 37 IPC_MESSAGE_UNHANDLED(handled = false) | 163 IPC_MESSAGE_UNHANDLED(handled = false) |
| 38 IPC_END_MESSAGE_MAP() | 164 IPC_END_MESSAGE_MAP() |
| 39 return handled; | 165 return handled; |
| 40 } | 166 } |
| 41 | 167 |
| 42 void ServiceDiscoveryMessageHandler::OnStartWatcher( | 168 void ServiceDiscoveryMessageHandler::OnStartWatcher( |
| 43 uint64 id, | 169 uint64 id, |
| 44 const std::string& service_type) { | 170 const std::string& service_type) { |
| 45 Initialize(); | 171 if (!Initialize()) |
| 172 return; |
| 46 DCHECK(!ContainsKey(service_watchers_, id)); | 173 DCHECK(!ContainsKey(service_watchers_, id)); |
| 47 scoped_ptr<ServiceWatcher> watcher( | 174 scoped_ptr<ServiceWatcher> watcher( |
| 48 service_discovery_client_->CreateServiceWatcher( | 175 service_discovery_client_->CreateServiceWatcher( |
| 49 service_type, | 176 service_type, |
| 50 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated, | 177 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceUpdated, |
| 51 base::Unretained(this), id))); | 178 base::Unretained(this), id))); |
| 52 watcher->Start(); | 179 watcher->Start(); |
| 53 service_watchers_[id].reset(watcher.release()); | 180 service_watchers_[id].reset(watcher.release()); |
| 54 } | 181 } |
| 55 | 182 |
| 56 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id, | 183 void ServiceDiscoveryMessageHandler::OnDiscoverServices(uint64 id, |
| 57 bool force_update) { | 184 bool force_update) { |
| 185 if (!service_discovery_client_) |
| 186 return; |
| 58 DCHECK(ContainsKey(service_watchers_, id)); | 187 DCHECK(ContainsKey(service_watchers_, id)); |
| 59 service_watchers_[id]->DiscoverNewServices(force_update); | 188 service_watchers_[id]->DiscoverNewServices(force_update); |
| 60 } | 189 } |
| 61 | 190 |
| 62 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) { | 191 void ServiceDiscoveryMessageHandler::OnDestroyWatcher(uint64 id) { |
| 192 if (!service_discovery_client_) |
| 193 return; |
| 63 DCHECK(ContainsKey(service_watchers_, id)); | 194 DCHECK(ContainsKey(service_watchers_, id)); |
| 64 service_watchers_.erase(id); | 195 service_watchers_.erase(id); |
| 65 } | 196 } |
| 66 | 197 |
| 67 void ServiceDiscoveryMessageHandler::OnResolveService( | 198 void ServiceDiscoveryMessageHandler::OnResolveService( |
| 68 uint64 id, | 199 uint64 id, |
| 69 const std::string& service_name) { | 200 const std::string& service_name) { |
| 70 Initialize(); | 201 if (!Initialize()) |
| 202 return; |
| 71 DCHECK(!ContainsKey(service_resolvers_, id)); | 203 DCHECK(!ContainsKey(service_resolvers_, id)); |
| 72 scoped_ptr<ServiceResolver> resolver( | 204 scoped_ptr<ServiceResolver> resolver( |
| 73 service_discovery_client_->CreateServiceResolver( | 205 service_discovery_client_->CreateServiceResolver( |
| 74 service_name, | 206 service_name, |
| 75 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved, | 207 base::Bind(&ServiceDiscoveryMessageHandler::OnServiceResolved, |
| 76 base::Unretained(this), id))); | 208 base::Unretained(this), id))); |
| 77 resolver->StartResolving(); | 209 resolver->StartResolving(); |
| 78 service_resolvers_[id].reset(resolver.release()); | 210 service_resolvers_[id].reset(resolver.release()); |
| 79 } | 211 } |
| 80 | 212 |
| 81 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) { | 213 void ServiceDiscoveryMessageHandler::OnDestroyResolver(uint64 id) { |
| 214 if (!service_discovery_client_) |
| 215 return; |
| 82 DCHECK(ContainsKey(service_resolvers_, id)); | 216 DCHECK(ContainsKey(service_resolvers_, id)); |
| 83 service_resolvers_.erase(id); | 217 service_resolvers_.erase(id); |
| 84 } | 218 } |
| 85 | 219 |
| 86 void ServiceDiscoveryMessageHandler::OnServiceUpdated( | 220 void ServiceDiscoveryMessageHandler::OnServiceUpdated( |
| 87 uint64 id, | 221 uint64 id, |
| 88 ServiceWatcher::UpdateType update, | 222 ServiceWatcher::UpdateType update, |
| 89 const std::string& name) { | 223 const std::string& name) { |
| 224 DCHECK(service_discovery_client_); |
| 90 content::UtilityThread::Get()->Send( | 225 content::UtilityThread::Get()->Send( |
| 91 new LocalDiscoveryHostMsg_WatcherCallback(id, update, name)); | 226 new LocalDiscoveryHostMsg_WatcherCallback(id, update, name)); |
| 92 } | 227 } |
| 93 | 228 |
| 94 void ServiceDiscoveryMessageHandler::OnServiceResolved( | 229 void ServiceDiscoveryMessageHandler::OnServiceResolved( |
| 95 uint64 id, | 230 uint64 id, |
| 96 ServiceResolver::RequestStatus status, | 231 ServiceResolver::RequestStatus status, |
| 97 const ServiceDescription& description) { | 232 const ServiceDescription& description) { |
| 233 DCHECK(service_discovery_client_); |
| 98 content::UtilityThread::Get()->Send( | 234 content::UtilityThread::Get()->Send( |
| 99 new LocalDiscoveryHostMsg_ResolverCallback(id, status, description)); | 235 new LocalDiscoveryHostMsg_ResolverCallback(id, status, description)); |
| 100 } | 236 } |
| 101 | 237 |
| 102 } // namespace local_discovery | 238 } // namespace local_discovery |
| 239 |
| OLD | NEW |