Index: components/cast_channel/cast_socket_service.cc |
diff --git a/components/cast_channel/cast_socket_service.cc b/components/cast_channel/cast_socket_service.cc |
index f58c9858e9212121ed2ad48758d411773747d708..b731bf540a4dfddf3b4cce31d3f22a7a234a4202 100644 |
--- a/components/cast_channel/cast_socket_service.cc |
+++ b/components/cast_channel/cast_socket_service.cc |
@@ -5,17 +5,57 @@ |
#include "components/cast_channel/cast_socket_service.h" |
#include "base/memory/ptr_util.h" |
+#include "components/cast_channel/cast_socket.h" |
+#include "components/cast_channel/cast_transport.h" |
+#include "components/cast_channel/keep_alive_delegate.h" |
+#include "components/cast_channel/logger.h" |
#include "content/public/browser/browser_thread.h" |
using content::BrowserThread; |
+namespace { |
+// Connect timeout for connect calls. |
+const int kConnectTimeoutSecs = 10; |
+ |
+// Ping interval |
+const int kPingIntervalInSecs = 5; |
+ |
+// Liveness timeout for connect calls, in milliseconds. If no message is |
+// received from the receiver during LIVENESS_INTERVAL, it is considered gone. |
+const int kConnectLivenessTimeoutSecs = kPingIntervalInSecs * 2; |
+} // namespace |
+ |
namespace cast_channel { |
+PassThroughMessageHandler::PassThroughMessageHandler() {} |
+ |
+PassThroughMessageHandler::~PassThroughMessageHandler() {} |
+ |
+void PassThroughMessageHandler::RegisterDelegate( |
+ std::unique_ptr<CastTransport::Delegate> delegate) { |
+ inner_delegate_ = std::move(delegate); |
+} |
+ |
+void PassThroughMessageHandler::OnError(ChannelError error_state) { |
+ DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
+ if (inner_delegate_) |
+ inner_delegate_->OnError(error_state); |
+} |
+ |
+void PassThroughMessageHandler::OnMessage(const CastMessage& message) { |
+ DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
+ if (inner_delegate_) |
+ inner_delegate_->OnMessage(message); |
+} |
+ |
+void PassThroughMessageHandler::Start() {} |
+ |
int CastSocketService::last_channel_id_ = 0; |
CastSocketService::CastSocketService() |
: RefcountedKeyedService( |
- BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)) { |
+ BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)), |
+ logger_(new Logger()) { |
DETACH_FROM_THREAD(thread_checker_); |
} |
@@ -23,35 +63,199 @@ CastSocketService::~CastSocketService() { |
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
} |
-int CastSocketService::AddSocket(std::unique_ptr<CastSocket> socket) { |
+std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) { |
+ DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
+ DCHECK(channel_id > 0); |
+ auto socket_record_it = socket_records_.find(channel_id); |
+ |
+ std::unique_ptr<CastSocket> socket; |
+ if (socket_record_it != socket_records_.end()) { |
+ auto* socket_record = socket_record_it->second.get(); |
+ socket = std::move(socket_record->cast_socket); |
+ // Invoke all pending OnOpen callbacks. |
+ for (const auto& on_open_callback : |
+ socket_record->pending_on_open_callbacks) { |
+ on_open_callback.Run(socket->id(), ChannelError::CONNECT_ERROR); |
+ } |
+ |
+ socket_records_.erase(socket_record_it); |
+ } |
+ return socket; |
+} |
+ |
+CastSocket* CastSocketService::GetSocket(int channel_id) const { |
+ DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
+ auto* socket_record = GetSocketRecord(channel_id); |
+ return socket_record ? socket_record->cast_socket.get() : nullptr; |
+} |
+ |
+int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, |
+ ChannelAuthType channel_auth, |
+ net::NetLog* net_log, |
+ const base::TimeDelta& connect_timeout, |
+ const base::TimeDelta& ping_interval, |
+ const base::TimeDelta& liveness_timeout, |
+ const scoped_refptr<Logger>& logger, |
+ uint64_t device_capabilities, |
+ const OnOpenCallback& open_cb) { |
+ DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
+ |
+ auto* socket_record = GetSocketRecord(ip_endpoint); |
+ auto* socket = socket_record ? socket_record->cast_socket.get() : nullptr; |
+ // If cast socket exists. |
+ if (socket) { |
+ switch (socket->ready_state()) { |
+ case ReadyState::NONE: |
+ NOTREACHED(); |
+ break; |
+ case ReadyState::CONNECTING: |
+ DCHECK(socket_record->pending_on_open_callbacks.size() > 0); |
+ socket_record->pending_on_open_callbacks.push_back(open_cb); |
+ break; |
+ case ReadyState::OPEN: |
+ open_cb.Run(socket->id(), ChannelError::NONE); |
+ break; |
+ case ReadyState::CLOSING: |
+ NOTREACHED(); |
mark a. foltz
2017/06/12 21:14:08
Why is it not possible for the socket to be CLOSIN
zhaobin
2017/06/20 01:37:42
ReadyState::CLOSING is not used anywhere.
|
+ break; |
+ case ReadyState::CLOSED: |
+ RemoveSocket(socket->id()); |
+ socket = nullptr; |
+ break; |
+ default: |
+ NOTREACHED(); |
+ } |
+ if (socket) |
+ return socket->id(); |
+ } |
+ |
+ // If cast socket does not exist. |
+ if (socket_for_test_) { |
+ socket = socket_for_test_.release(); |
+ } else { |
+ socket = new CastSocketImpl( |
+ ip_endpoint, channel_auth, net_log, connect_timeout, |
+ liveness_timeout > base::TimeDelta(), logger, device_capabilities); |
+ } |
+ |
+ PassThroughMessageHandler* pass_through_delegate = |
+ new PassThroughMessageHandler(); |
+ std::unique_ptr<CastTransport::Delegate> delegate(pass_through_delegate); |
mark a. foltz
2017/06/12 21:14:08
base::MakeUnique?
zhaobin
2017/06/20 01:37:42
Code removed.
|
+ AddSocketRecord(base::WrapUnique(socket), open_cb, pass_through_delegate); |
+ |
+ if (socket->keep_alive()) { |
+ // Wrap read delegate in a KeepAliveDelegate for timeout handling. |
+ KeepAliveDelegate* keep_alive = new KeepAliveDelegate( |
+ socket, logger_, std::move(delegate), ping_interval, liveness_timeout); |
+ if (injected_timeout_timer_) { |
+ keep_alive->SetTimersForTest(base::MakeUnique<base::Timer>(false, false), |
+ std::move(injected_timeout_timer_)); |
+ } |
+ delegate.reset(keep_alive); |
+ } |
+ |
+ socket->Connect(std::move(delegate), |
+ base::Bind(&CastSocketService::OnOpen, this, socket->id())); |
+ return socket->id(); |
+} |
+ |
+int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, |
+ net::NetLog* net_log, |
+ const OnOpenCallback& open_cb) { |
+ auto connect_timeout = base::TimeDelta::FromSeconds(kConnectTimeoutSecs); |
+ auto ping_interval = base::TimeDelta::FromSeconds(kPingIntervalInSecs); |
+ auto liveness_timeout = |
+ base::TimeDelta::FromSeconds(kConnectLivenessTimeoutSecs); |
+ return OpenSocket(ip_endpoint, ChannelAuthType::SSL_VERIFIED, net_log, |
+ connect_timeout, ping_interval, liveness_timeout, logger_, |
+ CastDeviceCapability::NONE, open_cb); |
+} |
+ |
+bool CastSocketService::RegisterDelegate( |
+ int channel_id, |
+ std::unique_ptr<CastTransport::Delegate> delegate) { |
+ DCHECK(channel_id > 0); |
+ auto* socket_record = GetSocketRecord(channel_id); |
+ if (!socket_record) |
+ return false; |
+ |
+ socket_record->pass_through_message_handler->RegisterDelegate( |
mark a. foltz
2017/06/12 21:14:08
Why does the CastSocketRecord have a "pass-through
zhaobin
2017/06/20 01:37:42
Code removed. Use observers instead of delegates (
|
+ std::move(delegate)); |
+ return true; |
+} |
+ |
+void CastSocketService::SetSocketForTest( |
+ std::unique_ptr<cast_channel::CastSocket> socket_for_test) { |
+ socket_for_test_ = std::move(socket_for_test); |
+} |
+ |
+void CastSocketService::SetPingTimeoutTimerForTest( |
+ std::unique_ptr<base::Timer> timer) { |
+ injected_timeout_timer_ = std::move(timer); |
+} |
+ |
+void CastSocketService::ShutdownOnUIThread() {} |
+ |
+int CastSocketService::AddSocketRecord( |
+ std::unique_ptr<CastSocket> socket, |
+ const OnOpenCallback& on_open_callback, |
+ PassThroughMessageHandler* message_handler) { |
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
DCHECK(socket); |
int id = ++last_channel_id_; |
socket->set_id(id); |
- sockets_.insert(std::make_pair(id, std::move(socket))); |
+ socket_records_.insert(std::make_pair( |
+ id, base::MakeUnique<CastSocketRecord>( |
+ std::move(socket), on_open_callback, message_handler))); |
return id; |
} |
-std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) { |
+CastSocketService::CastSocketRecord* CastSocketService::GetSocketRecord( |
+ int channel_id) const { |
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
DCHECK(channel_id > 0); |
- auto socket_it = sockets_.find(channel_id); |
+ const auto& socket_record_it = socket_records_.find(channel_id); |
+ return socket_record_it == socket_records_.end() |
+ ? nullptr |
+ : socket_record_it->second.get(); |
+} |
- std::unique_ptr<CastSocket> socket; |
- if (socket_it != sockets_.end()) { |
- socket = std::move(socket_it->second); |
- sockets_.erase(socket_it); |
+CastSocketService::CastSocketRecord* CastSocketService::GetSocketRecord( |
+ const net::IPEndPoint& ip_endpoint) const { |
+ DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
+ for (const auto& socket_record_it : socket_records_) { |
mark a. foltz
2017/06/12 21:14:08
Consider using std::find_if
zhaobin
2017/06/20 01:37:42
Done.
|
+ auto* socket_record = socket_record_it.second.get(); |
+ if (socket_record->cast_socket->ip_endpoint() == ip_endpoint) |
+ return socket_record; |
} |
- return socket; |
+ return nullptr; |
} |
-CastSocket* CastSocketService::GetSocket(int channel_id) const { |
+void CastSocketService::OnOpen(int channel_id, ChannelError error_state) { |
mark a. foltz
2017/06/12 21:14:08
It might make more sense for OnOpen to be declared
zhaobin
2017/06/20 01:37:42
Code removed. Let CastSocket track OnOpen callback
|
DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
- DCHECK(channel_id > 0); |
- const auto& socket_it = sockets_.find(channel_id); |
- return socket_it == sockets_.end() ? nullptr : socket_it->second.get(); |
+ auto* socket_record = GetSocketRecord(channel_id); |
+ if (!socket_record) |
+ return; |
+ |
+ // Invoke all pending OnOpen callbacks. |
+ for (const auto& on_open_callback : |
+ socket_record->pending_on_open_callbacks) { |
+ on_open_callback.Run(channel_id, error_state); |
+ } |
+ socket_record->pending_on_open_callbacks.clear(); |
} |
-void CastSocketService::ShutdownOnUIThread() {} |
+CastSocketService::CastSocketRecord::CastSocketRecord( |
+ std::unique_ptr<CastSocket> socket, |
+ const OnOpenCallback& on_open_callback, |
+ PassThroughMessageHandler* message_handler) |
+ : cast_socket(std::move(socket)), |
+ pass_through_message_handler(message_handler) { |
+ DCHECK(cast_socket); |
+ DCHECK(pass_through_message_handler); |
+ pending_on_open_callbacks.push_back(on_open_callback); |
+} |
+ |
+CastSocketService::CastSocketRecord::~CastSocketRecord() {} |
} // namespace cast_channel |