| Index: chrome/browser/extensions/api/cast_channel/cast_socket.cc
|
| ===================================================================
|
| --- chrome/browser/extensions/api/cast_channel/cast_socket.cc (revision 230522)
|
| +++ chrome/browser/extensions/api/cast_channel/cast_socket.cc (working copy)
|
| @@ -11,6 +11,7 @@
|
| #include "base/lazy_instance.h"
|
| #include "base/strings/string_number_conversions.h"
|
| #include "base/sys_byteorder.h"
|
| +#include "chrome/browser/extensions/api/cast_channel/cast_auth_util.h"
|
| #include "chrome/browser/extensions/api/cast_channel/cast_channel.pb.h"
|
| #include "chrome/browser/extensions/api/cast_channel/cast_message_util.h"
|
| #include "net/base/address_list.h"
|
| @@ -65,13 +66,14 @@
|
| const uint32 kMaxMessageSize = 65536;
|
|
|
| CastSocket::CastSocket(const std::string& owner_extension_id,
|
| - const GURL& url, CastSocket::Delegate* delegate,
|
| + const GURL& url,
|
| + CastSocket::Delegate* delegate,
|
| net::NetLog* net_log) :
|
| ApiResource(owner_extension_id),
|
| channel_id_(0),
|
| url_(url),
|
| delegate_(delegate),
|
| - is_secure_(false),
|
| + auth_required_(false),
|
| error_state_(CHANNEL_ERROR_NONE),
|
| ready_state_(READY_STATE_NONE),
|
| write_callback_pending_(false),
|
| @@ -97,18 +99,6 @@
|
| return url_;
|
| }
|
|
|
| -bool CastSocket::ExtractPeerCert(std::string* cert) {
|
| - CHECK(peer_cert_.empty());
|
| - net::SSLInfo ssl_info;
|
| - if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get())
|
| - return false;
|
| - bool result = net::X509Certificate::GetDEREncoded(
|
| - ssl_info.cert->os_cert_handle(), cert);
|
| - if (result)
|
| - DVLOG(1) << "Successfully extracted peer certificate: " << *cert;
|
| - return result;
|
| -}
|
| -
|
| scoped_ptr<net::TCPClientSocket> CastSocket::CreateTcpSocket() {
|
| net::AddressList addresses(ip_endpoint_);
|
| scoped_ptr<net::TCPClientSocket> tcp_socket(
|
| @@ -146,12 +136,45 @@
|
| connection.Pass(), host_and_port, ssl_config, context);
|
| }
|
|
|
| +bool CastSocket::ExtractPeerCert(std::string* cert) {
|
| + DCHECK(cert);
|
| + DCHECK(peer_cert_.empty());
|
| + net::SSLInfo ssl_info;
|
| + if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get())
|
| + return false;
|
| + bool result = net::X509Certificate::GetDEREncoded(
|
| + ssl_info.cert->os_cert_handle(), cert);
|
| + if (result)
|
| + DVLOG(1) << "Successfully extracted peer certificate: " << *cert;
|
| + return result;
|
| +}
|
| +
|
| +int CastSocket::SendAuthChallenge() {
|
| + CastMessage challenge_message;
|
| + CreateAuthChallengeMessage(&challenge_message);
|
| + DVLOG(1) << "Sending challenge: " << CastMessageToString(challenge_message);
|
| + return SendMessageInternal(
|
| + challenge_message,
|
| + base::Bind(&CastSocket::OnChallengeEvent, AsWeakPtr()));
|
| +}
|
| +
|
| +int CastSocket::ReadAuthChallengeReply() {
|
| + return ReadData();
|
| +}
|
| +
|
| void CastSocket::OnConnectComplete(int result) {
|
| int rv = DoConnectLoop(result);
|
| if (rv != net::ERR_IO_PENDING)
|
| DoConnectCallback(rv);
|
| }
|
|
|
| +void CastSocket::OnChallengeEvent(int result) {
|
| + // result >= 0 means read or write succeeded synchronously.
|
| + int rv = DoConnectLoop(result >= 0 ? net::OK : result);
|
| + if (rv != net::ERR_IO_PENDING)
|
| + DoConnectCallback(rv);
|
| +}
|
| +
|
| void CastSocket::Connect(const net::CompletionCallback& callback) {
|
| DCHECK(CalledOnValidThread());
|
| int result = net::ERR_CONNECTION_FAILED;
|
| @@ -162,7 +185,6 @@
|
| }
|
| if (!ParseChannelUrl(url_)) {
|
| CloseWithError(cast_channel::CHANNEL_ERROR_CONNECT_ERROR);
|
| - // TODO(mfoltz): Signal channel errors via |callback|
|
| callback.Run(result);
|
| return;
|
| }
|
| @@ -205,6 +227,16 @@
|
| case CONN_STATE_SSL_CONNECT_COMPLETE:
|
| rv = DoSslConnectComplete(rv);
|
| break;
|
| + case CONN_STATE_AUTH_CHALLENGE_SEND:
|
| + rv = DoAuthChallengeSend();
|
| + break;
|
| + case CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE:
|
| + rv = DoAuthChallengeSendComplete(rv);
|
| + break;
|
| + case CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE:
|
| + rv = DoAuthChallengeReplyComplete(rv);
|
| + break;
|
| +
|
| default:
|
| NOTREACHED() << "BUG in CastSocket state machine code";
|
| break;
|
| @@ -218,6 +250,7 @@
|
| }
|
|
|
| int CastSocket::DoTcpConnect() {
|
| + DVLOG(1) << "DoTcpConnect";
|
| next_state_ = CONN_STATE_TCP_CONNECT_COMPLETE;
|
| tcp_socket_ = CreateTcpSocket();
|
| return tcp_socket_->Connect(
|
| @@ -225,12 +258,14 @@
|
| }
|
|
|
| int CastSocket::DoTcpConnectComplete(int result) {
|
| + DVLOG(1) << "DoTcpConnectComplete: " << result;
|
| if (result == net::OK)
|
| next_state_ = CONN_STATE_SSL_CONNECT;
|
| return result;
|
| }
|
|
|
| int CastSocket::DoSslConnect() {
|
| + DVLOG(1) << "DoSslConnect";
|
| next_state_ = CONN_STATE_SSL_CONNECT_COMPLETE;
|
| socket_ = CreateSslSocket();
|
| return socket_->Connect(
|
| @@ -238,15 +273,45 @@
|
| }
|
|
|
| int CastSocket::DoSslConnectComplete(int result) {
|
| - // TODO(mfoltz,munjal): Authenticate the channel if is_secure_ == true.
|
| + DVLOG(1) << "DoSslConnectComplete: " << result;
|
| if (result == net::ERR_CERT_AUTHORITY_INVALID &&
|
| peer_cert_.empty() &&
|
| ExtractPeerCert(&peer_cert_)) {
|
| next_state_ = CONN_STATE_TCP_CONNECT;
|
| + } else if (result == net::OK && auth_required_) {
|
| + next_state_ = CONN_STATE_AUTH_CHALLENGE_SEND;
|
| }
|
| return result;
|
| }
|
|
|
| +int CastSocket::DoAuthChallengeSend() {
|
| + DVLOG(1) << "DoAuthChallengeSend";
|
| + next_state_ = CONN_STATE_AUTH_CHALLENGE_SEND_COMPLETE;
|
| + return SendAuthChallenge();
|
| +}
|
| +
|
| +int CastSocket::DoAuthChallengeSendComplete(int result) {
|
| + DVLOG(1) << "DoAuthChallengeSendComplete: " << result;
|
| + if (result != net::OK)
|
| + return result;
|
| + next_state_ = CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE;
|
| + return ReadAuthChallengeReply();
|
| +}
|
| +
|
| +int CastSocket::DoAuthChallengeReplyComplete(int result) {
|
| + DVLOG(1) << "DoAuthChallengeReplyComplete: " << result;
|
| + if (result != net::OK)
|
| + return result;
|
| + if (!VerifyChallengeReply())
|
| + return net::ERR_FAILED;
|
| + DVLOG(1) << "Auth challenge verification succeeded";
|
| + return net::OK;
|
| +}
|
| +
|
| +bool CastSocket::VerifyChallengeReply() {
|
| + return AuthenticateChallengeReply(*challenge_reply_.get(), peer_cert_);
|
| +}
|
| +
|
| void CastSocket::DoConnectCallback(int result) {
|
| ready_state_ = (result == net::OK) ? READY_STATE_OPEN : READY_STATE_CLOSED;
|
| error_state_ = (result == net::OK) ?
|
| @@ -276,31 +341,40 @@
|
| callback.Run(result);
|
| return;
|
| }
|
| - WriteRequest write_request(callback);
|
| CastMessage message_proto;
|
| - if (!MessageInfoToCastMessage(message, &message_proto) ||
|
| - !write_request.SetContent(message_proto)) {
|
| + if (!MessageInfoToCastMessage(message, &message_proto)) {
|
| CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE);
|
| // TODO(mfoltz): Do a better job of signaling cast_channel errors to the
|
| // caller.
|
| callback.Run(net::OK);
|
| return;
|
| }
|
| + SendMessageInternal(message_proto, callback);
|
| + /*
|
| + if (result >= 0) {
|
| + callback.Run(result);
|
| + } else if (result != net::ERR_IO_PENDING && result != net::OK) {
|
| + CloseWithError(cast_channel::CHANNEL_ERROR_INVALID_MESSAGE);
|
| + callback.Run(net::ERR_FAILED);
|
| + }*/
|
| +}
|
| +
|
| +int CastSocket::SendMessageInternal(const CastMessage& message_proto,
|
| + const net::CompletionCallback& callback) {
|
| + WriteRequest write_request(callback);
|
| + if (!write_request.SetContent(message_proto))
|
| + return net::ERR_FAILED;
|
| write_queue_.push(write_request);
|
| - WriteData();
|
| + return WriteData();
|
| }
|
|
|
| -void CastSocket::WriteData() {
|
| +int CastSocket::WriteData() {
|
| DCHECK(CalledOnValidThread());
|
| DVLOG(1) << "WriteData q = " << write_queue_.size();
|
| if (write_queue_.empty() || write_callback_pending_)
|
| - return;
|
| + return net::ERR_FAILED;
|
|
|
| WriteRequest& request = write_queue_.front();
|
| - if (ready_state_ != READY_STATE_OPEN) {
|
| - request.callback.Run(net::ERR_FAILED);
|
| - return;
|
| - }
|
|
|
| DVLOG(1) << "WriteData byte_count = " << request.io_buffer->size() <<
|
| " bytes_written " << request.io_buffer->BytesConsumed();
|
| @@ -311,10 +385,10 @@
|
| request.io_buffer->BytesRemaining(),
|
| base::Bind(&CastSocket::OnWriteData, AsWeakPtr()));
|
|
|
| - DVLOG(1) << "WriteData result = " << result;
|
| -
|
| if (result != net::ERR_IO_PENDING)
|
| OnWriteData(result);
|
| +
|
| + return result;
|
| }
|
|
|
| void CastSocket::OnWriteData(int result) {
|
| @@ -358,11 +432,10 @@
|
| WriteData();
|
| }
|
|
|
| -void CastSocket::ReadData() {
|
| +int CastSocket::ReadData() {
|
| DCHECK(CalledOnValidThread());
|
| - if (!socket_.get() || ready_state_ != READY_STATE_OPEN) {
|
| - return;
|
| - }
|
| + if (!socket_.get())
|
| + return net::ERR_FAILED;
|
| DCHECK(!read_callback_pending_);
|
| read_callback_pending_ = true;
|
| // Figure out if we are reading the header or body, and the remaining bytes.
|
| @@ -389,13 +462,14 @@
|
| } else if (result != net::ERR_IO_PENDING) {
|
| CloseWithError(CHANNEL_ERROR_SOCKET_ERROR);
|
| }
|
| + return result;
|
| }
|
|
|
| void CastSocket::OnReadData(int result) {
|
| DCHECK(CalledOnValidThread());
|
| - DVLOG(1) << "OnReadData result = " << result <<
|
| - " header offset = " << header_read_buffer_->offset() <<
|
| - " body offset = " << body_read_buffer_->offset();
|
| + DVLOG(1) << "OnReadData result = " << result
|
| + << " header offset = " << header_read_buffer_->offset()
|
| + << " body offset = " << body_read_buffer_->offset();
|
| read_callback_pending_ = false;
|
| if (result <= 0) {
|
| CloseWithError(CHANNEL_ERROR_SOCKET_ERROR);
|
| @@ -458,8 +532,12 @@
|
| body_read_buffer_->StartOfBuffer(),
|
| current_message_size_))
|
| return false;
|
| - DVLOG(1) << "Parsed message " << MessageProtoToString(message_proto);
|
| - if (delegate_) {
|
| + DVLOG(1) << "Parsed message " << CastMessageToString(message_proto);
|
| + // If the message is an auth message then we handle it internally.
|
| + if (IsAuthMessage(message_proto)) {
|
| + challenge_reply_.reset(new CastMessage(message_proto));
|
| + OnChallengeEvent(net::OK);
|
| + } else if (delegate_) {
|
| MessageInfo message;
|
| if (!CastMessageToMessageInfo(message_proto, &message))
|
| return false;
|
| @@ -496,9 +574,9 @@
|
| bool CastSocket::ParseChannelUrl(const GURL& url) {
|
| DVLOG(1) << "url = " + url.spec();
|
| if (url.SchemeIs(kCastInsecureScheme)) {
|
| - is_secure_ = false;
|
| + auth_required_ = false;
|
| } else if (url.SchemeIs(kCastSecureScheme)) {
|
| - is_secure_ = true;
|
| + auth_required_ = true;
|
| } else {
|
| return false;
|
| }
|
|
|