Index: extensions/browser/api/cast_channel/cast_socket.cc |
diff --git a/extensions/browser/api/cast_channel/cast_socket.cc b/extensions/browser/api/cast_channel/cast_socket.cc |
index dac6e0d395b84e113b6811912310f1f3c6219ba0..617a53730d1025f4e4e664f173620b6c5d40711e 100644 |
--- a/extensions/browser/api/cast_channel/cast_socket.cc |
+++ b/extensions/browser/api/cast_channel/cast_socket.cc |
@@ -18,6 +18,7 @@ |
#include "extensions/browser/api/cast_channel/cast_auth_util.h" |
#include "extensions/browser/api/cast_channel/cast_channel.pb.h" |
#include "extensions/browser/api/cast_channel/cast_message_util.h" |
+#include "extensions/browser/api/cast_channel/cast_socket_framer.h" |
#include "extensions/browser/api/cast_channel/logger.h" |
#include "extensions/browser/api/cast_channel/logger_util.h" |
#include "net/base/address_list.h" |
@@ -47,7 +48,6 @@ namespace { |
// after 9 failed probes. So the total idle time before close is 10 * |
// kTcpKeepAliveDelaySecs. |
const int kTcpKeepAliveDelaySecs = 10; |
- |
} // namespace |
namespace extensions { |
@@ -189,8 +189,6 @@ CastSocket::CastSocket(const std::string& owner_extension_id, |
ip_endpoint_(ip_endpoint), |
channel_auth_(channel_auth), |
delegate_(delegate), |
- current_message_size_(0), |
- current_message_(new CastMessage()), |
net_log_(net_log), |
logger_(logger), |
connect_timeout_(timeout), |
@@ -207,12 +205,10 @@ CastSocket::CastSocket(const std::string& owner_extension_id, |
net_log_source_.type = net::NetLog::SOURCE_SOCKET; |
net_log_source_.id = net_log_->NextID(); |
- // Reuse these buffers for each message. |
- header_read_buffer_ = new net::GrowableIOBuffer(); |
- header_read_buffer_->SetCapacity(MessageHeader::header_size()); |
- body_read_buffer_ = new net::GrowableIOBuffer(); |
- body_read_buffer_->SetCapacity(MessageHeader::max_message_size()); |
- current_read_buffer_ = header_read_buffer_; |
+ // Buffer is reused across messages. |
+ read_buffer_ = new net::GrowableIOBuffer(); |
+ read_buffer_->SetCapacity(kMaxMessageSizeBytes); |
+ framer_.reset(new MessageFramer(read_buffer_)); |
} |
CastSocket::~CastSocket() { |
@@ -272,16 +268,18 @@ bool CastSocket::ExtractPeerCert(std::string* cert) { |
DCHECK(cert); |
DCHECK(peer_cert_.empty()); |
net::SSLInfo ssl_info; |
- if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) |
+ if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) { |
return false; |
+ } |
logger_->LogSocketEvent(channel_id_, proto::SSL_INFO_OBTAINED); |
bool result = net::X509Certificate::GetDEREncoded( |
ssl_info.cert->os_cert_handle(), cert); |
- if (result) |
+ if (result) { |
VLOG_WITH_CONNECTION(1) << "Successfully extracted peer certificate: " |
<< *cert; |
+ } |
logger_->LogSocketEventWithRv( |
channel_id_, proto::DER_ENCODED_CERT_OBTAIN, result ? 1 : 0); |
@@ -488,8 +486,9 @@ void CastSocket::DoAuthChallengeSendWriteComplete(int result) { |
int CastSocket::DoAuthChallengeSendComplete(int result) { |
VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result; |
- if (result < 0) |
+ if (result < 0) { |
return result; |
+ } |
SetConnectState(CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE); |
// Post a task to start read loop so that DoReadLoop is not nested inside |
@@ -502,10 +501,12 @@ int CastSocket::DoAuthChallengeSendComplete(int result) { |
int CastSocket::DoAuthChallengeReplyComplete(int result) { |
VLOG_WITH_CONNECTION(1) << "DoAuthChallengeReplyComplete: " << result; |
- if (result < 0) |
+ if (result < 0) { |
return result; |
- if (!VerifyChallengeReply()) |
+ } |
+ if (!VerifyChallengeReply()) { |
return net::ERR_FAILED; |
+ } |
VLOG_WITH_CONNECTION(1) << "Auth challenge verification succeeded"; |
return net::OK; |
} |
@@ -666,12 +667,14 @@ void CastSocket::DoWriteLoop(int result) { |
// If write loop is done because the queue is empty then set write |
// state to NONE |
- if (write_queue_.empty()) |
+ if (write_queue_.empty()) { |
SetWriteState(WRITE_STATE_NONE); |
+ } |
// Write loop is done - if the result is ERR_FAILED then close with error. |
- if (rv == net::ERR_FAILED) |
+ if (rv == net::ERR_FAILED) { |
CloseWithError(); |
+ } |
} |
int CastSocket::DoWrite() { |
@@ -705,10 +708,11 @@ int CastSocket::DoWriteComplete(int result) { |
WriteRequest& request = write_queue_.front(); |
scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer; |
io_buffer->DidConsume(result); |
- if (io_buffer->BytesRemaining() == 0) // Message fully sent |
+ if (io_buffer->BytesRemaining() == 0) { // Message fully sent |
SetWriteState(WRITE_STATE_DO_CALLBACK); |
- else |
+ } else { |
SetWriteState(WRITE_STATE_WRITE); |
+ } |
return net::OK; |
} |
@@ -824,24 +828,14 @@ void CastSocket::DoReadLoop(int result) { |
int CastSocket::DoRead() { |
SetReadState(READ_STATE_READ_COMPLETE); |
- // Figure out whether to read header or body, and the remaining bytes. |
- uint32 num_bytes_to_read = 0; |
- if (header_read_buffer_->RemainingCapacity() > 0) { |
- current_read_buffer_ = header_read_buffer_; |
- num_bytes_to_read = header_read_buffer_->RemainingCapacity(); |
- CHECK_LE(num_bytes_to_read, MessageHeader::header_size()); |
- } else { |
- DCHECK_GT(current_message_size_, 0U); |
- num_bytes_to_read = current_message_size_ - body_read_buffer_->offset(); |
- current_read_buffer_ = body_read_buffer_; |
- CHECK_LE(num_bytes_to_read, MessageHeader::max_message_size()); |
- } |
- CHECK_GT(num_bytes_to_read, 0U); |
+ |
+ // Determine how many bytes need to be read. |
+ size_t num_bytes_to_read = framer_->BytesRequested(); |
// Read up to num_bytes_to_read into |current_read_buffer_|. |
int rv = socket_->Read( |
- current_read_buffer_.get(), |
- num_bytes_to_read, |
+ read_buffer_.get(), |
+ base::checked_cast<uint32>(num_bytes_to_read), |
base::Bind(&CastSocket::DoReadLoop, base::Unretained(this))); |
logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, rv); |
@@ -849,10 +843,8 @@ int CastSocket::DoRead() { |
} |
int CastSocket::DoReadComplete(int result) { |
- VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result |
- << " header offset = " |
- << header_read_buffer_->offset() |
- << " body offset = " << body_read_buffer_->offset(); |
+ VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result; |
+ |
if (result <= 0) { // 0 means EOF: the peer closed the socket |
VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket"; |
SetErrorState(CHANNEL_ERROR_SOCKET_ERROR); |
@@ -860,44 +852,25 @@ int CastSocket::DoReadComplete(int result) { |
return result == 0 ? net::ERR_FAILED : result; |
} |
- // Some data was read. Move the offset in the current buffer forward. |
- CHECK_LE(current_read_buffer_->offset() + result, |
- current_read_buffer_->capacity()); |
- current_read_buffer_->set_offset(current_read_buffer_->offset() + result); |
- |
- if (current_read_buffer_.get() == header_read_buffer_.get() && |
- current_read_buffer_->RemainingCapacity() == 0) { |
- // A full header is read, process the contents. |
- if (!ProcessHeader()) { |
- SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); |
- SetReadState(READ_STATE_ERROR); |
- } else { |
- // Processed header, now read the body. |
- SetReadState(READ_STATE_READ); |
- } |
- } else if (current_read_buffer_.get() == body_read_buffer_.get() && |
- static_cast<uint32>(current_read_buffer_->offset()) == |
- current_message_size_) { |
- // Store a copy of current_message_size_ since it will be reset by |
- // ProcessBody(). |
- uint32 message_size = current_message_size_; |
- // Full body is read, process the contents. |
- if (ProcessBody()) { |
- logger_->LogSocketEventForMessage( |
- channel_id_, |
- proto::MESSAGE_READ, |
- current_message_->namespace_(), |
- base::StringPrintf("Message size: %u", message_size)); |
- SetReadState(READ_STATE_DO_CALLBACK); |
- } else { |
- SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); |
- SetReadState(READ_STATE_ERROR); |
- } |
+ size_t message_size; |
+ DCHECK(current_message_.get() == NULL); |
+ current_message_ = framer_->Ingest(result, &message_size, &error_state_); |
+ if (current_message_.get()) { |
+ DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE); |
+ DCHECK_GT(message_size, static_cast<size_t>(0)); |
+ logger_->LogSocketEventForMessage( |
+ channel_id_, |
+ proto::MESSAGE_READ, |
+ current_message_->namespace_(), |
+ base::StringPrintf("Message size: %zu", message_size)); |
+ SetReadState(READ_STATE_DO_CALLBACK); |
+ } else if (error_state_ != CHANNEL_ERROR_NONE) { |
+ DCHECK(current_message_.get() == NULL); |
+ SetReadState(READ_STATE_ERROR); |
} else { |
- // Have not received full header or full body yet; keep reading. |
+ DCHECK(current_message_.get() == NULL); |
SetReadState(READ_STATE_READ); |
} |
- |
return net::OK; |
} |
@@ -909,17 +882,19 @@ int CastSocket::DoReadCallback() { |
challenge_reply_.reset(new CastMessage(message)); |
logger_->LogSocketEvent(channel_id_, proto::RECEIVED_CHALLENGE_REPLY); |
PostTaskToStartConnectLoop(net::OK); |
+ current_message_.reset(); |
return net::OK; |
} else { |
SetReadState(READ_STATE_ERROR); |
SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); |
+ current_message_.reset(); |
return net::ERR_INVALID_RESPONSE; |
} |
} |
MessageInfo message_info; |
if (!CastMessageToMessageInfo(message, &message_info)) { |
- current_message_->Clear(); |
+ current_message_.reset(); |
SetReadState(READ_STATE_ERROR); |
SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); |
return net::ERR_INVALID_RESPONSE; |
@@ -930,7 +905,7 @@ int CastSocket::DoReadCallback() { |
message.namespace_(), |
std::string()); |
delegate_->OnMessage(this, message_info); |
- current_message_->Clear(); |
+ current_message_.reset(); |
return net::OK; |
} |
@@ -940,49 +915,7 @@ int CastSocket::DoReadError(int result) { |
return net::ERR_FAILED; |
} |
-bool CastSocket::ProcessHeader() { |
- CHECK_EQ(static_cast<uint32>(header_read_buffer_->offset()), |
- MessageHeader::header_size()); |
- MessageHeader header; |
- MessageHeader::ReadFromIOBuffer(header_read_buffer_.get(), &header); |
- if (header.message_size > MessageHeader::max_message_size()) |
- return false; |
- VLOG_WITH_CONNECTION(2) << "Parsed header { message_size: " |
- << header.message_size << " }"; |
- current_message_size_ = header.message_size; |
- return true; |
-} |
- |
-bool CastSocket::ProcessBody() { |
- CHECK_EQ(static_cast<uint32>(body_read_buffer_->offset()), |
- current_message_size_); |
- if (!current_message_->ParseFromArray( |
- body_read_buffer_->StartOfBuffer(), current_message_size_)) { |
- return false; |
- } |
- current_message_size_ = 0; |
- header_read_buffer_->set_offset(0); |
- body_read_buffer_->set_offset(0); |
- current_read_buffer_ = header_read_buffer_; |
- return true; |
-} |
- |
-// static |
-bool CastSocket::Serialize(const CastMessage& message_proto, |
- std::string* message_data) { |
- DCHECK(message_data); |
- message_proto.SerializeToString(message_data); |
- size_t message_size = message_data->size(); |
- if (message_size > MessageHeader::max_message_size()) { |
- message_data->clear(); |
- return false; |
- } |
- CastSocket::MessageHeader header; |
- header.SetMessageSize(message_size); |
- header.PrependToString(message_data); |
- return true; |
-} |
void CastSocket::CloseWithError() { |
DCHECK(CalledOnValidThread()); |
@@ -1043,50 +976,15 @@ void CastSocket::SetWriteState(WriteState write_state) { |
} |
} |
-CastSocket::MessageHeader::MessageHeader() : message_size(0) { } |
- |
-void CastSocket::MessageHeader::SetMessageSize(size_t size) { |
- DCHECK_LT(size, static_cast<size_t>(kuint32max)); |
- DCHECK_GT(size, 0U); |
- message_size = size; |
-} |
- |
-// TODO(mfoltz): Investigate replacing header serialization with base::Pickle, |
-// if bit-for-bit compatible. |
-void CastSocket::MessageHeader::PrependToString(std::string* str) { |
- MessageHeader output = *this; |
- output.message_size = base::HostToNet32(message_size); |
- size_t header_size = base::checked_cast<size_t, uint32>( |
- MessageHeader::header_size()); |
- scoped_ptr<char, base::FreeDeleter> char_array( |
- static_cast<char*>(malloc(header_size))); |
- memcpy(char_array.get(), &output, header_size); |
- str->insert(0, char_array.get(), header_size); |
-} |
- |
-// TODO(mfoltz): Investigate replacing header deserialization with base::Pickle, |
-// if bit-for-bit compatible. |
-void CastSocket::MessageHeader::ReadFromIOBuffer( |
- net::GrowableIOBuffer* buffer, MessageHeader* header) { |
- uint32 message_size; |
- size_t header_size = base::checked_cast<size_t, uint32>( |
- MessageHeader::header_size()); |
- memcpy(&message_size, buffer->StartOfBuffer(), header_size); |
- header->message_size = base::NetToHost32(message_size); |
-} |
- |
-std::string CastSocket::MessageHeader::ToString() { |
- return "{message_size: " + base::UintToString(message_size) + "}"; |
-} |
- |
CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback) |
: callback(callback) { } |
bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) { |
DCHECK(!io_buffer.get()); |
std::string message_data; |
- if (!Serialize(message_proto, &message_data)) |
+ if (!MessageFramer::Serialize(message_proto, &message_data)) { |
return false; |
+ } |
message_namespace = message_proto.namespace_(); |
io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data), |
message_data.size()); |
@@ -1094,7 +992,6 @@ bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) { |
} |
CastSocket::WriteRequest::~WriteRequest() { } |
- |
} // namespace cast_channel |
} // namespace core_api |
} // namespace extensions |