Index: chrome/browser/extensions/api/cast_channel/cast_socket.cc |
=================================================================== |
--- chrome/browser/extensions/api/cast_channel/cast_socket.cc (revision 230132) |
+++ chrome/browser/extensions/api/cast_channel/cast_socket.cc (working copy) |
@@ -11,6 +11,7 @@ |
#include "base/lazy_instance.h" |
#include "base/strings/string_number_conversions.h" |
#include "base/sys_byteorder.h" |
+#include "chrome/browser/extensions/api/cast_channel/cast_auth_util.h" |
#include "chrome/browser/extensions/api/cast_channel/cast_channel.pb.h" |
#include "chrome/browser/extensions/api/cast_channel/cast_message_util.h" |
#include "net/base/address_list.h" |
@@ -65,13 +66,14 @@ |
const uint32 kMaxMessageSize = 65536; |
CastSocket::CastSocket(const std::string& owner_extension_id, |
- const GURL& url, CastSocket::Delegate* delegate, |
+ const GURL& url, |
+ CastSocket::Delegate* delegate, |
net::NetLog* net_log) : |
ApiResource(owner_extension_id), |
channel_id_(0), |
url_(url), |
delegate_(delegate), |
- is_secure_(false), |
+ auth_required_(false), |
error_state_(CHANNEL_ERROR_NONE), |
ready_state_(READY_STATE_NONE), |
write_callback_pending_(false), |
@@ -97,18 +99,6 @@ |
return url_; |
} |
-bool CastSocket::ExtractPeerCert(std::string* cert) { |
- CHECK(peer_cert_.empty()); |
- net::SSLInfo ssl_info; |
- if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) |
- return false; |
- bool result = net::X509Certificate::GetDEREncoded( |
- ssl_info.cert->os_cert_handle(), cert); |
- if (result) |
- DVLOG(1) << "Successfully extracted peer certificate: " << *cert; |
- return result; |
-} |
- |
scoped_ptr<net::TCPClientSocket> CastSocket::CreateTcpSocket() { |
net::AddressList addresses(ip_endpoint_); |
scoped_ptr<net::TCPClientSocket> tcp_socket( |
@@ -146,12 +136,43 @@ |
connection.Pass(), host_and_port, ssl_config, context); |
} |
+bool CastSocket::ExtractPeerCert(std::string* cert) { |
+ DCHECK(cert); |
+ DCHECK(peer_cert_.empty()); |
+ net::SSLInfo ssl_info; |
+ if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) |
+ return false; |
+ bool result = net::X509Certificate::GetDEREncoded( |
+ ssl_info.cert->os_cert_handle(), cert); |
+ if (result) |
+ DVLOG(1) << "Successfully extracted peer certificate: " << *cert; |
+ return result; |
+} |
+ |
+int CastSocket::SendAuthChallenge() { |
+ CastMessage challenge_message; |
+ CreateAuthChallengeMessage(&challenge_message); |
+ return SendMessageInternal( |
+ challenge_message, |
+ base::Bind(&CastSocket::OnChallengeEvent, AsWeakPtr())); |
+} |
+ |
+int CastSocket::ReadAuthChallengeReply() { |
+ return ReadData(); |
+} |
+ |
void CastSocket::OnConnectComplete(int result) { |
int rv = DoConnectLoop(result); |
if (rv != net::ERR_IO_PENDING) |
DoConnectCallback(rv); |
} |
+void CastSocket::OnChallengeEvent(int result) { |
+ int rv = DoConnectLoop(result); |
+ if (rv != net::ERR_IO_PENDING) |
+ DoConnectCallback(rv); |
+} |
+ |
void CastSocket::Connect(const net::CompletionCallback& callback) { |
DCHECK(CalledOnValidThread()); |
int result = net::ERR_CONNECTION_FAILED; |
@@ -162,7 +183,6 @@ |
} |
if (!ParseChannelUrl(url_)) { |
CloseWithError(cast_channel::CHANNEL_ERROR_CONNECT_ERROR); |
- // TODO(mfoltz): Signal channel errors via |callback| |
callback.Run(result); |
return; |
} |
@@ -205,6 +225,16 @@ |
case CONN_STATE_SSL_CONNECT_COMPLETE: |
rv = DoSslConnectComplete(rv); |
break; |
+ case CONN_STATE_AUTH_CHALLENGE_SEND: |
+ rv = DoAuthChallengeSend(); |
+ break; |
+ case CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE: |
+ rv = DoAuthChallengeSendComplete(rv); |
+ break; |
+ case CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE: |
+ rv = DoAuthChallengeReplyComplete(rv); |
+ break; |
+ |
default: |
NOTREACHED() << "BUG in CastSocket state machine code"; |
break; |
@@ -238,15 +268,40 @@ |
} |
int CastSocket::DoSslConnectComplete(int result) { |
- // TODO(mfoltz,munjal): Authenticate the channel if is_secure_ == true. |
if (result == net::ERR_CERT_AUTHORITY_INVALID && |
peer_cert_.empty() && |
ExtractPeerCert(&peer_cert_)) { |
next_state_ = CONN_STATE_TCP_CONNECT; |
+ } else if (result == net::OK && auth_required_) { |
+ next_state_ = CONN_STATE_AUTH_CHALLENGE_SEND; |
} |
return result; |
} |
+int CastSocket::DoAuthChallengeSend() { |
+ next_state_ = CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE; |
+ return SendAuthChallenge(); |
+} |
+ |
+int CastSocket::DoAuthChallengeSendComplete(int result) { |
+ if (result != net::OK) |
+ return result; |
+ next_state_ = CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE; |
+ return ReadAuthChallengeReply(); |
+} |
+ |
+int CastSocket::DoAuthChallengeReplyComplete(int result) { |
+ if (result != net::OK) |
+ return result; |
+ if (!VerifyChallengeReply()) |
+ return net::ERR_FAILED; |
+ return net::OK; |
+} |
+ |
+bool CastSocket::VerifyChallengeReply() { |
+ return AuthenticateChallengeReply(*challenge_reply_.get(), peer_cert_); |
+} |
+ |
void CastSocket::DoConnectCallback(int result) { |
ready_state_ = (result == net::OK) ? READY_STATE_OPEN : READY_STATE_CLOSED; |
error_state_ = (result == net::OK) ? |
@@ -276,31 +331,38 @@ |
callback.Run(result); |
return; |
} |
- WriteRequest write_request(callback); |
CastMessage message_proto; |
- if (!MessageInfoToCastMessage(message, &message_proto) || |
- !write_request.SetContent(message_proto)) { |
+ if (!MessageInfoToCastMessage(message, &message_proto)) { |
CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE); |
// TODO(mfoltz): Do a better job of signaling cast_channel errors to the |
// caller. |
callback.Run(net::OK); |
return; |
} |
+ result = SendMessageInternal(message_proto, callback); |
+ if (result == net::ERR_FAILED || result == net::OK) |
Ryan Sleevi
2013/10/23 23:28:51
BUG: You fail to handle a number of cases here.
S
Munjal (Google)
2013/10/24 00:00:56
Done.
|
+ { |
Ryan Sleevi
2013/10/23 23:28:51
style: brace goes on line above
Munjal (Google)
2013/10/24 00:00:56
Done.
|
+ CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE); |
+ callback.Run(net::ERR_FAILED); |
+ } |
+} |
+ |
+int CastSocket::SendMessageInternal(const CastMessage& message_proto, |
+ const net::CompletionCallback& callback) { |
+ WriteRequest write_request(callback); |
+ if (!write_request.SetContent(message_proto)) |
+ return net::ERR_FAILED; |
write_queue_.push(write_request); |
- WriteData(); |
+ return WriteData(); |
Ryan Sleevi
2013/10/23 23:28:51
BUG: Potentially bad recursion design
SendMessage
Munjal (Google)
2013/10/24 00:00:56
The recursive nature of Write was already there in
|
} |
-void CastSocket::WriteData() { |
+int CastSocket::WriteData() { |
DCHECK(CalledOnValidThread()); |
DVLOG(1) << "WriteData q = " << write_queue_.size(); |
if (write_queue_.empty() || write_callback_pending_) |
- return; |
+ return net::ERR_FAILED; |
WriteRequest& request = write_queue_.front(); |
- if (ready_state_ != READY_STATE_OPEN) { |
- request.callback.Run(net::ERR_FAILED); |
- return; |
- } |
DVLOG(1) << "WriteData byte_count = " << request.io_buffer->size() << |
" bytes_written " << request.io_buffer->BytesConsumed(); |
@@ -311,10 +373,10 @@ |
request.io_buffer->BytesRemaining(), |
base::Bind(&CastSocket::OnWriteData, AsWeakPtr())); |
- DVLOG(1) << "WriteData result = " << result; |
- |
if (result != net::ERR_IO_PENDING) |
OnWriteData(result); |
+ |
+ return result; |
} |
void CastSocket::OnWriteData(int result) { |
@@ -358,11 +420,10 @@ |
WriteData(); |
} |
-void CastSocket::ReadData() { |
+int CastSocket::ReadData() { |
DCHECK(CalledOnValidThread()); |
- if (!socket_.get() || ready_state_ != READY_STATE_OPEN) { |
- return; |
- } |
+ if (!socket_.get()) |
+ return net::ERR_FAILED; |
DCHECK(!read_callback_pending_); |
read_callback_pending_ = true; |
// Figure out if we are reading the header or body, and the remaining bytes. |
@@ -389,13 +450,14 @@ |
} else if (result != net::ERR_IO_PENDING) { |
CloseWithError(CHANNEL_ERROR_SOCKET_ERROR); |
} |
+ return result; |
} |
void CastSocket::OnReadData(int result) { |
DCHECK(CalledOnValidThread()); |
- DVLOG(1) << "OnReadData result = " << result << |
- " header offset = " << header_read_buffer_->offset() << |
- " body offset = " << body_read_buffer_->offset(); |
+ DVLOG(1) << "OnReadData result = " << result |
+ << " header offset = " << header_read_buffer_->offset() |
+ << " body offset = " << body_read_buffer_->offset(); |
read_callback_pending_ = false; |
if (result <= 0) { |
CloseWithError(CHANNEL_ERROR_SOCKET_ERROR); |
@@ -458,8 +520,12 @@ |
body_read_buffer_->StartOfBuffer(), |
current_message_size_)) |
return false; |
- DVLOG(1) << "Parsed message " << MessageProtoToString(message_proto); |
- if (delegate_) { |
+ DVLOG(1) << "Parsed message " << CastMessageToString(message_proto); |
+ // If the message is an auth message then we handle it internally. |
+ if (IsAuthMessage(message_proto)) { |
+ challenge_reply_.reset(new CastMessage(message_proto)); |
+ OnChallengeEvent(net::OK); |
+ } else if (delegate_) { |
MessageInfo message; |
if (!CastMessageToMessageInfo(message_proto, &message)) |
return false; |
@@ -496,9 +562,9 @@ |
bool CastSocket::ParseChannelUrl(const GURL& url) { |
DVLOG(1) << "url = " + url.spec(); |
if (url.SchemeIs(kCastInsecureScheme)) { |
- is_secure_ = false; |
+ auth_required_ = false; |
} else if (url.SchemeIs(kCastSecureScheme)) { |
- is_secure_ = true; |
+ auth_required_ = true; |
} else { |
return false; |
} |