Index: extensions/browser/api/cast_channel/cast_channel_api.cc |
diff --git a/extensions/browser/api/cast_channel/cast_channel_api.cc b/extensions/browser/api/cast_channel/cast_channel_api.cc |
index ee888cb2602536cfbea8f03a134ce05f7e9d6b43..9d711a48c0c7910804a435ce46cd76f4fa3ae8c6 100644 |
--- a/extensions/browser/api/cast_channel/cast_channel_api.cc |
+++ b/extensions/browser/api/cast_channel/cast_channel_api.cc |
@@ -131,27 +131,14 @@ CastChannelAPI::GetFactoryInstance() { |
return g_factory.Pointer(); |
} |
-scoped_ptr<CastSocket> CastChannelAPI::CreateCastSocket( |
- const std::string& extension_id, const net::IPEndPoint& ip_endpoint, |
- ChannelAuthType channel_auth, const base::TimeDelta& timeout) { |
- if (socket_for_test_.get()) { |
- return socket_for_test_.Pass(); |
- } else { |
- return scoped_ptr<CastSocket>( |
- new CastSocket(extension_id, |
- ip_endpoint, |
- channel_auth, |
- this, |
- ExtensionsBrowserClient::Get()->GetNetLog(), |
- timeout, |
- logger_)); |
- } |
-} |
- |
void CastChannelAPI::SetSocketForTest(scoped_ptr<CastSocket> socket_for_test) { |
socket_for_test_ = socket_for_test.Pass(); |
} |
+scoped_ptr<cast_channel::CastSocket> CastChannelAPI::GetSocketForTest() { |
+ return socket_for_test_.Pass(); |
+} |
+ |
void CastChannelAPI::OnError(const CastSocket* socket, |
cast_channel::ChannelError error_state, |
const cast_channel::LastErrors& last_errors) { |
@@ -367,13 +354,19 @@ bool CastChannelOpenFunction::Prepare() { |
void CastChannelOpenFunction::AsyncWorkStart() { |
DCHECK(api_); |
DCHECK(ip_endpoint_.get()); |
- scoped_ptr<CastSocket> socket = api_->CreateCastSocket( |
- extension_->id(), |
- *ip_endpoint_, |
- channel_auth_, |
- base::TimeDelta::FromMilliseconds(connect_info_->timeout.get() |
- ? *connect_info_->timeout |
- : kDefaultConnectTimeoutMillis)); |
+ scoped_ptr<CastSocket> socket = api_->GetSocketForTest(); |
+ if (!socket.get()) { |
+ socket.reset(new CastSocket( |
+ extension_->id(), |
+ *ip_endpoint_, |
+ channel_auth_, |
+ api_, |
+ ExtensionsBrowserClient::Get()->GetNetLog(), |
+ base::TimeDelta::FromMilliseconds(connect_info_->timeout.get() |
+ ? *connect_info_->timeout |
+ : kDefaultConnectTimeoutMillis), |
+ api_->GetLogger())); |
+ } |
new_channel_id_ = AddSocket(socket.release()); |
CastSocket* new_socket = GetSocket(new_channel_id_); |
api_->GetLogger()->LogNewSocketEvent(*new_socket); |
@@ -424,11 +417,15 @@ bool CastChannelSendFunction::Prepare() { |
} |
void CastChannelSendFunction::AsyncWorkStart() { |
- CastSocket* socket = GetSocketOrCompleteWithError( |
- params_->channel.channel_id); |
- if (socket) |
- socket->SendMessage(params_->message, |
- base::Bind(&CastChannelSendFunction::OnSend, this)); |
+ CastSocket* socket = GetSocket(params_->channel.channel_id); |
+ if (!socket) { |
+ SetResultFromError(params_->channel.channel_id, |
+ cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID); |
+ AsyncWorkCompleted(); |
+ return; |
+ } |
+ socket->SendMessage(params_->message, |
+ base::Bind(&CastChannelSendFunction::OnSend, this)); |
} |
void CastChannelSendFunction::OnSend(int result) { |
@@ -455,10 +452,14 @@ bool CastChannelCloseFunction::Prepare() { |
} |
void CastChannelCloseFunction::AsyncWorkStart() { |
- CastSocket* socket = GetSocketOrCompleteWithError( |
- params_->channel.channel_id); |
- if (socket) |
+ CastSocket* socket = GetSocket(params_->channel.channel_id); |
+ if (!socket) { |
+ SetResultFromError(params_->channel.channel_id, |
+ cast_channel::CHANNEL_ERROR_INVALID_CHANNEL_ID); |
+ AsyncWorkCompleted(); |
+ } else { |
socket->Close(base::Bind(&CastChannelCloseFunction::OnClose, this)); |
+ } |
} |
void CastChannelCloseFunction::OnClose(int result) { |