OLD | NEW |
---|---|
1 // Copyright 2017 The Chromium Authors. All rights reserved. | 1 // Copyright 2017 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "components/cast_channel/cast_socket_service.h" | 5 #include "components/cast_channel/cast_socket_service.h" |
6 | 6 |
7 #include "base/memory/ptr_util.h" | 7 #include "base/memory/ptr_util.h" |
8 #include "components/cast_channel/cast_socket.h" | |
9 #include "components/cast_channel/cast_transport.h" | |
10 #include "components/cast_channel/keep_alive_delegate.h" | |
11 #include "components/cast_channel/logger.h" | |
8 #include "content/public/browser/browser_thread.h" | 12 #include "content/public/browser/browser_thread.h" |
9 | 13 |
10 using content::BrowserThread; | 14 using content::BrowserThread; |
11 | 15 |
16 namespace { | |
17 // Connect timeout for connect calls. | |
18 const int kConnectTimeoutSecs = 10; | |
19 | |
20 // Ping interval | |
21 const int kPingIntervalInSecs = 5; | |
22 | |
23 // Liveness timeout for connect calls, in milliseconds. If no message is | |
24 // received from the receiver during LIVENESS_INTERVAL, it is considered gone. | |
25 const int kConnectLivenessTimeoutSecs = kPingIntervalInSecs * 2; | |
26 } // namespace | |
27 | |
12 namespace cast_channel { | 28 namespace cast_channel { |
13 | 29 |
30 PassThroughMessageHandler::PassThroughMessageHandler() {} | |
31 | |
32 PassThroughMessageHandler::~PassThroughMessageHandler() {} | |
33 | |
34 void PassThroughMessageHandler::RegisterDelegate( | |
35 std::unique_ptr<CastTransport::Delegate> delegate) { | |
36 inner_delegate_ = std::move(delegate); | |
37 } | |
38 | |
39 void PassThroughMessageHandler::OnError(ChannelError error_state) { | |
40 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
41 if (inner_delegate_) | |
42 inner_delegate_->OnError(error_state); | |
43 } | |
44 | |
45 void PassThroughMessageHandler::OnMessage(const CastMessage& message) { | |
46 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
47 if (inner_delegate_) | |
48 inner_delegate_->OnMessage(message); | |
49 } | |
50 | |
51 void PassThroughMessageHandler::Start() {} | |
52 | |
14 int CastSocketService::last_channel_id_ = 0; | 53 int CastSocketService::last_channel_id_ = 0; |
15 | 54 |
16 CastSocketService::CastSocketService() | 55 CastSocketService::CastSocketService() |
17 : RefcountedKeyedService( | 56 : RefcountedKeyedService( |
18 BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)) { | 57 BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)), |
58 logger_(new Logger()) { | |
19 DETACH_FROM_THREAD(thread_checker_); | 59 DETACH_FROM_THREAD(thread_checker_); |
20 } | 60 } |
21 | 61 |
22 CastSocketService::~CastSocketService() { | 62 CastSocketService::~CastSocketService() { |
23 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 63 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
24 } | 64 } |
25 | 65 |
26 int CastSocketService::AddSocket(std::unique_ptr<CastSocket> socket) { | |
27 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
28 DCHECK(socket); | |
29 int id = ++last_channel_id_; | |
30 socket->set_id(id); | |
31 sockets_.insert(std::make_pair(id, std::move(socket))); | |
32 return id; | |
33 } | |
34 | |
35 std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) { | 66 std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) { |
36 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 67 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
37 DCHECK(channel_id > 0); | 68 DCHECK(channel_id > 0); |
38 auto socket_it = sockets_.find(channel_id); | 69 auto socket_record_it = socket_records_.find(channel_id); |
39 | 70 |
40 std::unique_ptr<CastSocket> socket; | 71 std::unique_ptr<CastSocket> socket; |
41 if (socket_it != sockets_.end()) { | 72 if (socket_record_it != socket_records_.end()) { |
42 socket = std::move(socket_it->second); | 73 auto* socket_record = socket_record_it->second.get(); |
43 sockets_.erase(socket_it); | 74 socket = std::move(socket_record->cast_socket); |
75 // Invoke all pending OnOpen callbacks. | |
76 for (const auto& on_open_callback : | |
77 socket_record->pending_on_open_callbacks) { | |
78 on_open_callback.Run(socket->id(), ChannelError::CONNECT_ERROR); | |
79 } | |
80 | |
81 socket_records_.erase(socket_record_it); | |
44 } | 82 } |
45 return socket; | 83 return socket; |
46 } | 84 } |
47 | 85 |
48 CastSocket* CastSocketService::GetSocket(int channel_id) const { | 86 CastSocket* CastSocketService::GetSocket(int channel_id) const { |
49 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | 87 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
88 auto* socket_record = GetSocketRecord(channel_id); | |
89 return socket_record ? socket_record->cast_socket.get() : nullptr; | |
90 } | |
91 | |
92 int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, | |
93 ChannelAuthType channel_auth, | |
94 net::NetLog* net_log, | |
95 const base::TimeDelta& connect_timeout, | |
96 const base::TimeDelta& ping_interval, | |
97 const base::TimeDelta& liveness_timeout, | |
98 const scoped_refptr<Logger>& logger, | |
99 uint64_t device_capabilities, | |
100 const OnOpenCallback& open_cb) { | |
101 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
102 | |
103 auto* socket_record = GetSocketRecord(ip_endpoint); | |
104 auto* socket = socket_record ? socket_record->cast_socket.get() : nullptr; | |
105 // If cast socket exists. | |
106 if (socket) { | |
107 switch (socket->ready_state()) { | |
108 case ReadyState::NONE: | |
109 NOTREACHED(); | |
110 break; | |
111 case ReadyState::CONNECTING: | |
112 DCHECK(socket_record->pending_on_open_callbacks.size() > 0); | |
113 socket_record->pending_on_open_callbacks.push_back(open_cb); | |
114 break; | |
115 case ReadyState::OPEN: | |
116 open_cb.Run(socket->id(), ChannelError::NONE); | |
117 break; | |
118 case ReadyState::CLOSING: | |
119 NOTREACHED(); | |
mark a. foltz
2017/06/12 21:14:08
Why is it not possible for the socket to be CLOSIN
zhaobin
2017/06/20 01:37:42
ReadyState::CLOSING is not used anywhere.
| |
120 break; | |
121 case ReadyState::CLOSED: | |
122 RemoveSocket(socket->id()); | |
123 socket = nullptr; | |
124 break; | |
125 default: | |
126 NOTREACHED(); | |
127 } | |
128 if (socket) | |
129 return socket->id(); | |
130 } | |
131 | |
132 // If cast socket does not exist. | |
133 if (socket_for_test_) { | |
134 socket = socket_for_test_.release(); | |
135 } else { | |
136 socket = new CastSocketImpl( | |
137 ip_endpoint, channel_auth, net_log, connect_timeout, | |
138 liveness_timeout > base::TimeDelta(), logger, device_capabilities); | |
139 } | |
140 | |
141 PassThroughMessageHandler* pass_through_delegate = | |
142 new PassThroughMessageHandler(); | |
143 std::unique_ptr<CastTransport::Delegate> delegate(pass_through_delegate); | |
mark a. foltz
2017/06/12 21:14:08
base::MakeUnique?
zhaobin
2017/06/20 01:37:42
Code removed.
| |
144 AddSocketRecord(base::WrapUnique(socket), open_cb, pass_through_delegate); | |
145 | |
146 if (socket->keep_alive()) { | |
147 // Wrap read delegate in a KeepAliveDelegate for timeout handling. | |
148 KeepAliveDelegate* keep_alive = new KeepAliveDelegate( | |
149 socket, logger_, std::move(delegate), ping_interval, liveness_timeout); | |
150 if (injected_timeout_timer_) { | |
151 keep_alive->SetTimersForTest(base::MakeUnique<base::Timer>(false, false), | |
152 std::move(injected_timeout_timer_)); | |
153 } | |
154 delegate.reset(keep_alive); | |
155 } | |
156 | |
157 socket->Connect(std::move(delegate), | |
158 base::Bind(&CastSocketService::OnOpen, this, socket->id())); | |
159 return socket->id(); | |
160 } | |
161 | |
162 int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, | |
163 net::NetLog* net_log, | |
164 const OnOpenCallback& open_cb) { | |
165 auto connect_timeout = base::TimeDelta::FromSeconds(kConnectTimeoutSecs); | |
166 auto ping_interval = base::TimeDelta::FromSeconds(kPingIntervalInSecs); | |
167 auto liveness_timeout = | |
168 base::TimeDelta::FromSeconds(kConnectLivenessTimeoutSecs); | |
169 return OpenSocket(ip_endpoint, ChannelAuthType::SSL_VERIFIED, net_log, | |
170 connect_timeout, ping_interval, liveness_timeout, logger_, | |
171 CastDeviceCapability::NONE, open_cb); | |
172 } | |
173 | |
174 bool CastSocketService::RegisterDelegate( | |
175 int channel_id, | |
176 std::unique_ptr<CastTransport::Delegate> delegate) { | |
50 DCHECK(channel_id > 0); | 177 DCHECK(channel_id > 0); |
51 const auto& socket_it = sockets_.find(channel_id); | 178 auto* socket_record = GetSocketRecord(channel_id); |
52 return socket_it == sockets_.end() ? nullptr : socket_it->second.get(); | 179 if (!socket_record) |
180 return false; | |
181 | |
182 socket_record->pass_through_message_handler->RegisterDelegate( | |
mark a. foltz
2017/06/12 21:14:08
Why does the CastSocketRecord have a "pass-through
zhaobin
2017/06/20 01:37:42
Code removed. Use observers instead of delegates (
| |
183 std::move(delegate)); | |
184 return true; | |
185 } | |
186 | |
187 void CastSocketService::SetSocketForTest( | |
188 std::unique_ptr<cast_channel::CastSocket> socket_for_test) { | |
189 socket_for_test_ = std::move(socket_for_test); | |
190 } | |
191 | |
192 void CastSocketService::SetPingTimeoutTimerForTest( | |
193 std::unique_ptr<base::Timer> timer) { | |
194 injected_timeout_timer_ = std::move(timer); | |
53 } | 195 } |
54 | 196 |
55 void CastSocketService::ShutdownOnUIThread() {} | 197 void CastSocketService::ShutdownOnUIThread() {} |
56 | 198 |
199 int CastSocketService::AddSocketRecord( | |
200 std::unique_ptr<CastSocket> socket, | |
201 const OnOpenCallback& on_open_callback, | |
202 PassThroughMessageHandler* message_handler) { | |
203 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
204 DCHECK(socket); | |
205 int id = ++last_channel_id_; | |
206 socket->set_id(id); | |
207 socket_records_.insert(std::make_pair( | |
208 id, base::MakeUnique<CastSocketRecord>( | |
209 std::move(socket), on_open_callback, message_handler))); | |
210 return id; | |
211 } | |
212 | |
213 CastSocketService::CastSocketRecord* CastSocketService::GetSocketRecord( | |
214 int channel_id) const { | |
215 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
216 DCHECK(channel_id > 0); | |
217 const auto& socket_record_it = socket_records_.find(channel_id); | |
218 return socket_record_it == socket_records_.end() | |
219 ? nullptr | |
220 : socket_record_it->second.get(); | |
221 } | |
222 | |
223 CastSocketService::CastSocketRecord* CastSocketService::GetSocketRecord( | |
224 const net::IPEndPoint& ip_endpoint) const { | |
225 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
226 for (const auto& socket_record_it : socket_records_) { | |
mark a. foltz
2017/06/12 21:14:08
Consider using std::find_if
zhaobin
2017/06/20 01:37:42
Done.
| |
227 auto* socket_record = socket_record_it.second.get(); | |
228 if (socket_record->cast_socket->ip_endpoint() == ip_endpoint) | |
229 return socket_record; | |
230 } | |
231 return nullptr; | |
232 } | |
233 | |
234 void CastSocketService::OnOpen(int channel_id, ChannelError error_state) { | |
mark a. foltz
2017/06/12 21:14:08
It might make more sense for OnOpen to be declared
zhaobin
2017/06/20 01:37:42
Code removed. Let CastSocket track OnOpen callback
| |
235 DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); | |
236 auto* socket_record = GetSocketRecord(channel_id); | |
237 if (!socket_record) | |
238 return; | |
239 | |
240 // Invoke all pending OnOpen callbacks. | |
241 for (const auto& on_open_callback : | |
242 socket_record->pending_on_open_callbacks) { | |
243 on_open_callback.Run(channel_id, error_state); | |
244 } | |
245 socket_record->pending_on_open_callbacks.clear(); | |
246 } | |
247 | |
248 CastSocketService::CastSocketRecord::CastSocketRecord( | |
249 std::unique_ptr<CastSocket> socket, | |
250 const OnOpenCallback& on_open_callback, | |
251 PassThroughMessageHandler* message_handler) | |
252 : cast_socket(std::move(socket)), | |
253 pass_through_message_handler(message_handler) { | |
254 DCHECK(cast_socket); | |
255 DCHECK(pass_through_message_handler); | |
256 pending_on_open_callbacks.push_back(on_open_callback); | |
257 } | |
258 | |
259 CastSocketService::CastSocketRecord::~CastSocketRecord() {} | |
260 | |
57 } // namespace cast_channel | 261 } // namespace cast_channel |
OLD | NEW |