Chromium Code Reviews| Index: components/cast_channel/cast_socket_service.cc |
| diff --git a/components/cast_channel/cast_socket_service.cc b/components/cast_channel/cast_socket_service.cc |
| index f58c9858e9212121ed2ad48758d411773747d708..b731bf540a4dfddf3b4cce31d3f22a7a234a4202 100644 |
| --- a/components/cast_channel/cast_socket_service.cc |
| +++ b/components/cast_channel/cast_socket_service.cc |
| @@ -5,17 +5,57 @@ |
| #include "components/cast_channel/cast_socket_service.h" |
| #include "base/memory/ptr_util.h" |
| +#include "components/cast_channel/cast_socket.h" |
| +#include "components/cast_channel/cast_transport.h" |
| +#include "components/cast_channel/keep_alive_delegate.h" |
| +#include "components/cast_channel/logger.h" |
| #include "content/public/browser/browser_thread.h" |
| using content::BrowserThread; |
| +namespace { |
| +// Connect timeout for connect calls. |
| +const int kConnectTimeoutSecs = 10; |
| + |
| +// Ping interval |
| +const int kPingIntervalInSecs = 5; |
| + |
| +// Liveness timeout for connect calls, in milliseconds. If no message is |
| +// received from the receiver during LIVENESS_INTERVAL, it is considered gone. |
| +const int kConnectLivenessTimeoutSecs = kPingIntervalInSecs * 2; |
| +} // namespace |
| + |
| namespace cast_channel { |
| +PassThroughMessageHandler::PassThroughMessageHandler() {} |
| + |
| +PassThroughMessageHandler::~PassThroughMessageHandler() {} |
| + |
| +void PassThroughMessageHandler::RegisterDelegate( |
| + std::unique_ptr<CastTransport::Delegate> delegate) { |
| + inner_delegate_ = std::move(delegate); |
| +} |
| + |
| +void PassThroughMessageHandler::OnError(ChannelError error_state) { |
| + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| + if (inner_delegate_) |
| + inner_delegate_->OnError(error_state); |
| +} |
| + |
| +void PassThroughMessageHandler::OnMessage(const CastMessage& message) { |
| + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| + if (inner_delegate_) |
| + inner_delegate_->OnMessage(message); |
| +} |
| + |
| +void PassThroughMessageHandler::Start() {} |
| + |
| int CastSocketService::last_channel_id_ = 0; |
| CastSocketService::CastSocketService() |
| : RefcountedKeyedService( |
| - BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)) { |
| + BrowserThread::GetTaskRunnerForThread(BrowserThread::IO)), |
| + logger_(new Logger()) { |
| DETACH_FROM_THREAD(thread_checker_); |
| } |
| @@ -23,35 +63,199 @@ CastSocketService::~CastSocketService() { |
| DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| } |
| -int CastSocketService::AddSocket(std::unique_ptr<CastSocket> socket) { |
| +std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) { |
| + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| + DCHECK(channel_id > 0); |
| + auto socket_record_it = socket_records_.find(channel_id); |
| + |
| + std::unique_ptr<CastSocket> socket; |
| + if (socket_record_it != socket_records_.end()) { |
| + auto* socket_record = socket_record_it->second.get(); |
| + socket = std::move(socket_record->cast_socket); |
| + // Invoke all pending OnOpen callbacks. |
| + for (const auto& on_open_callback : |
| + socket_record->pending_on_open_callbacks) { |
| + on_open_callback.Run(socket->id(), ChannelError::CONNECT_ERROR); |
| + } |
| + |
| + socket_records_.erase(socket_record_it); |
| + } |
| + return socket; |
| +} |
| + |
| +CastSocket* CastSocketService::GetSocket(int channel_id) const { |
| + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| + auto* socket_record = GetSocketRecord(channel_id); |
| + return socket_record ? socket_record->cast_socket.get() : nullptr; |
| +} |
| + |
| +int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, |
| + ChannelAuthType channel_auth, |
| + net::NetLog* net_log, |
| + const base::TimeDelta& connect_timeout, |
| + const base::TimeDelta& ping_interval, |
| + const base::TimeDelta& liveness_timeout, |
| + const scoped_refptr<Logger>& logger, |
| + uint64_t device_capabilities, |
| + const OnOpenCallback& open_cb) { |
| + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| + |
| + auto* socket_record = GetSocketRecord(ip_endpoint); |
| + auto* socket = socket_record ? socket_record->cast_socket.get() : nullptr; |
| + // If cast socket exists. |
| + if (socket) { |
| + switch (socket->ready_state()) { |
| + case ReadyState::NONE: |
| + NOTREACHED(); |
| + break; |
| + case ReadyState::CONNECTING: |
| + DCHECK(socket_record->pending_on_open_callbacks.size() > 0); |
| + socket_record->pending_on_open_callbacks.push_back(open_cb); |
| + break; |
| + case ReadyState::OPEN: |
| + open_cb.Run(socket->id(), ChannelError::NONE); |
| + break; |
| + case ReadyState::CLOSING: |
| + NOTREACHED(); |
|
mark a. foltz
2017/06/12 21:14:08
Why is it not possible for the socket to be CLOSIN
zhaobin
2017/06/20 01:37:42
ReadyState::CLOSING is not used anywhere.
|
| + break; |
| + case ReadyState::CLOSED: |
| + RemoveSocket(socket->id()); |
| + socket = nullptr; |
| + break; |
| + default: |
| + NOTREACHED(); |
| + } |
| + if (socket) |
| + return socket->id(); |
| + } |
| + |
| + // If cast socket does not exist. |
| + if (socket_for_test_) { |
| + socket = socket_for_test_.release(); |
| + } else { |
| + socket = new CastSocketImpl( |
| + ip_endpoint, channel_auth, net_log, connect_timeout, |
| + liveness_timeout > base::TimeDelta(), logger, device_capabilities); |
| + } |
| + |
| + PassThroughMessageHandler* pass_through_delegate = |
| + new PassThroughMessageHandler(); |
| + std::unique_ptr<CastTransport::Delegate> delegate(pass_through_delegate); |
|
mark a. foltz
2017/06/12 21:14:08
base::MakeUnique?
zhaobin
2017/06/20 01:37:42
Code removed.
|
| + AddSocketRecord(base::WrapUnique(socket), open_cb, pass_through_delegate); |
| + |
| + if (socket->keep_alive()) { |
| + // Wrap read delegate in a KeepAliveDelegate for timeout handling. |
| + KeepAliveDelegate* keep_alive = new KeepAliveDelegate( |
| + socket, logger_, std::move(delegate), ping_interval, liveness_timeout); |
| + if (injected_timeout_timer_) { |
| + keep_alive->SetTimersForTest(base::MakeUnique<base::Timer>(false, false), |
| + std::move(injected_timeout_timer_)); |
| + } |
| + delegate.reset(keep_alive); |
| + } |
| + |
| + socket->Connect(std::move(delegate), |
| + base::Bind(&CastSocketService::OnOpen, this, socket->id())); |
| + return socket->id(); |
| +} |
| + |
| +int CastSocketService::OpenSocket(const net::IPEndPoint& ip_endpoint, |
| + net::NetLog* net_log, |
| + const OnOpenCallback& open_cb) { |
| + auto connect_timeout = base::TimeDelta::FromSeconds(kConnectTimeoutSecs); |
| + auto ping_interval = base::TimeDelta::FromSeconds(kPingIntervalInSecs); |
| + auto liveness_timeout = |
| + base::TimeDelta::FromSeconds(kConnectLivenessTimeoutSecs); |
| + return OpenSocket(ip_endpoint, ChannelAuthType::SSL_VERIFIED, net_log, |
| + connect_timeout, ping_interval, liveness_timeout, logger_, |
| + CastDeviceCapability::NONE, open_cb); |
| +} |
| + |
| +bool CastSocketService::RegisterDelegate( |
| + int channel_id, |
| + std::unique_ptr<CastTransport::Delegate> delegate) { |
| + DCHECK(channel_id > 0); |
| + auto* socket_record = GetSocketRecord(channel_id); |
| + if (!socket_record) |
| + return false; |
| + |
| + socket_record->pass_through_message_handler->RegisterDelegate( |
|
mark a. foltz
2017/06/12 21:14:08
Why does the CastSocketRecord have a "pass-through
zhaobin
2017/06/20 01:37:42
Code removed. Use observers instead of delegates (
|
| + std::move(delegate)); |
| + return true; |
| +} |
| + |
| +void CastSocketService::SetSocketForTest( |
| + std::unique_ptr<cast_channel::CastSocket> socket_for_test) { |
| + socket_for_test_ = std::move(socket_for_test); |
| +} |
| + |
| +void CastSocketService::SetPingTimeoutTimerForTest( |
| + std::unique_ptr<base::Timer> timer) { |
| + injected_timeout_timer_ = std::move(timer); |
| +} |
| + |
| +void CastSocketService::ShutdownOnUIThread() {} |
| + |
| +int CastSocketService::AddSocketRecord( |
| + std::unique_ptr<CastSocket> socket, |
| + const OnOpenCallback& on_open_callback, |
| + PassThroughMessageHandler* message_handler) { |
| DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| DCHECK(socket); |
| int id = ++last_channel_id_; |
| socket->set_id(id); |
| - sockets_.insert(std::make_pair(id, std::move(socket))); |
| + socket_records_.insert(std::make_pair( |
| + id, base::MakeUnique<CastSocketRecord>( |
| + std::move(socket), on_open_callback, message_handler))); |
| return id; |
| } |
| -std::unique_ptr<CastSocket> CastSocketService::RemoveSocket(int channel_id) { |
| +CastSocketService::CastSocketRecord* CastSocketService::GetSocketRecord( |
| + int channel_id) const { |
| DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| DCHECK(channel_id > 0); |
| - auto socket_it = sockets_.find(channel_id); |
| + const auto& socket_record_it = socket_records_.find(channel_id); |
| + return socket_record_it == socket_records_.end() |
| + ? nullptr |
| + : socket_record_it->second.get(); |
| +} |
| - std::unique_ptr<CastSocket> socket; |
| - if (socket_it != sockets_.end()) { |
| - socket = std::move(socket_it->second); |
| - sockets_.erase(socket_it); |
| +CastSocketService::CastSocketRecord* CastSocketService::GetSocketRecord( |
| + const net::IPEndPoint& ip_endpoint) const { |
| + DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| + for (const auto& socket_record_it : socket_records_) { |
|
mark a. foltz
2017/06/12 21:14:08
Consider using std::find_if
zhaobin
2017/06/20 01:37:42
Done.
|
| + auto* socket_record = socket_record_it.second.get(); |
| + if (socket_record->cast_socket->ip_endpoint() == ip_endpoint) |
| + return socket_record; |
| } |
| - return socket; |
| + return nullptr; |
| } |
| -CastSocket* CastSocketService::GetSocket(int channel_id) const { |
| +void CastSocketService::OnOpen(int channel_id, ChannelError error_state) { |
|
mark a. foltz
2017/06/12 21:14:08
It might make more sense for OnOpen to be declared
zhaobin
2017/06/20 01:37:42
Code removed. Let CastSocket track OnOpen callback
|
| DCHECK_CALLED_ON_VALID_THREAD(thread_checker_); |
| - DCHECK(channel_id > 0); |
| - const auto& socket_it = sockets_.find(channel_id); |
| - return socket_it == sockets_.end() ? nullptr : socket_it->second.get(); |
| + auto* socket_record = GetSocketRecord(channel_id); |
| + if (!socket_record) |
| + return; |
| + |
| + // Invoke all pending OnOpen callbacks. |
| + for (const auto& on_open_callback : |
| + socket_record->pending_on_open_callbacks) { |
| + on_open_callback.Run(channel_id, error_state); |
| + } |
| + socket_record->pending_on_open_callbacks.clear(); |
| } |
| -void CastSocketService::ShutdownOnUIThread() {} |
| +CastSocketService::CastSocketRecord::CastSocketRecord( |
| + std::unique_ptr<CastSocket> socket, |
| + const OnOpenCallback& on_open_callback, |
| + PassThroughMessageHandler* message_handler) |
| + : cast_socket(std::move(socket)), |
| + pass_through_message_handler(message_handler) { |
| + DCHECK(cast_socket); |
| + DCHECK(pass_through_message_handler); |
| + pending_on_open_callbacks.push_back(on_open_callback); |
| +} |
| + |
| +CastSocketService::CastSocketRecord::~CastSocketRecord() {} |
| } // namespace cast_channel |