| Index: extensions/browser/api/cast_channel/cast_socket.cc
|
| diff --git a/extensions/browser/api/cast_channel/cast_socket.cc b/extensions/browser/api/cast_channel/cast_socket.cc
|
| index dac6e0d395b84e113b6811912310f1f3c6219ba0..617a53730d1025f4e4e664f173620b6c5d40711e 100644
|
| --- a/extensions/browser/api/cast_channel/cast_socket.cc
|
| +++ b/extensions/browser/api/cast_channel/cast_socket.cc
|
| @@ -18,6 +18,7 @@
|
| #include "extensions/browser/api/cast_channel/cast_auth_util.h"
|
| #include "extensions/browser/api/cast_channel/cast_channel.pb.h"
|
| #include "extensions/browser/api/cast_channel/cast_message_util.h"
|
| +#include "extensions/browser/api/cast_channel/cast_socket_framer.h"
|
| #include "extensions/browser/api/cast_channel/logger.h"
|
| #include "extensions/browser/api/cast_channel/logger_util.h"
|
| #include "net/base/address_list.h"
|
| @@ -47,7 +48,6 @@ namespace {
|
| // after 9 failed probes. So the total idle time before close is 10 *
|
| // kTcpKeepAliveDelaySecs.
|
| const int kTcpKeepAliveDelaySecs = 10;
|
| -
|
| } // namespace
|
|
|
| namespace extensions {
|
| @@ -189,8 +189,6 @@ CastSocket::CastSocket(const std::string& owner_extension_id,
|
| ip_endpoint_(ip_endpoint),
|
| channel_auth_(channel_auth),
|
| delegate_(delegate),
|
| - current_message_size_(0),
|
| - current_message_(new CastMessage()),
|
| net_log_(net_log),
|
| logger_(logger),
|
| connect_timeout_(timeout),
|
| @@ -207,12 +205,10 @@ CastSocket::CastSocket(const std::string& owner_extension_id,
|
| net_log_source_.type = net::NetLog::SOURCE_SOCKET;
|
| net_log_source_.id = net_log_->NextID();
|
|
|
| - // Reuse these buffers for each message.
|
| - header_read_buffer_ = new net::GrowableIOBuffer();
|
| - header_read_buffer_->SetCapacity(MessageHeader::header_size());
|
| - body_read_buffer_ = new net::GrowableIOBuffer();
|
| - body_read_buffer_->SetCapacity(MessageHeader::max_message_size());
|
| - current_read_buffer_ = header_read_buffer_;
|
| + // Buffer is reused across messages.
|
| + read_buffer_ = new net::GrowableIOBuffer();
|
| + read_buffer_->SetCapacity(kMaxMessageSizeBytes);
|
| + framer_.reset(new MessageFramer(read_buffer_));
|
| }
|
|
|
| CastSocket::~CastSocket() {
|
| @@ -272,16 +268,18 @@ bool CastSocket::ExtractPeerCert(std::string* cert) {
|
| DCHECK(cert);
|
| DCHECK(peer_cert_.empty());
|
| net::SSLInfo ssl_info;
|
| - if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get())
|
| + if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) {
|
| return false;
|
| + }
|
|
|
| logger_->LogSocketEvent(channel_id_, proto::SSL_INFO_OBTAINED);
|
|
|
| bool result = net::X509Certificate::GetDEREncoded(
|
| ssl_info.cert->os_cert_handle(), cert);
|
| - if (result)
|
| + if (result) {
|
| VLOG_WITH_CONNECTION(1) << "Successfully extracted peer certificate: "
|
| << *cert;
|
| + }
|
|
|
| logger_->LogSocketEventWithRv(
|
| channel_id_, proto::DER_ENCODED_CERT_OBTAIN, result ? 1 : 0);
|
| @@ -488,8 +486,9 @@ void CastSocket::DoAuthChallengeSendWriteComplete(int result) {
|
|
|
| int CastSocket::DoAuthChallengeSendComplete(int result) {
|
| VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result;
|
| - if (result < 0)
|
| + if (result < 0) {
|
| return result;
|
| + }
|
| SetConnectState(CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE);
|
|
|
| // Post a task to start read loop so that DoReadLoop is not nested inside
|
| @@ -502,10 +501,12 @@ int CastSocket::DoAuthChallengeSendComplete(int result) {
|
|
|
| int CastSocket::DoAuthChallengeReplyComplete(int result) {
|
| VLOG_WITH_CONNECTION(1) << "DoAuthChallengeReplyComplete: " << result;
|
| - if (result < 0)
|
| + if (result < 0) {
|
| return result;
|
| - if (!VerifyChallengeReply())
|
| + }
|
| + if (!VerifyChallengeReply()) {
|
| return net::ERR_FAILED;
|
| + }
|
| VLOG_WITH_CONNECTION(1) << "Auth challenge verification succeeded";
|
| return net::OK;
|
| }
|
| @@ -666,12 +667,14 @@ void CastSocket::DoWriteLoop(int result) {
|
|
|
| // If write loop is done because the queue is empty then set write
|
| // state to NONE
|
| - if (write_queue_.empty())
|
| + if (write_queue_.empty()) {
|
| SetWriteState(WRITE_STATE_NONE);
|
| + }
|
|
|
| // Write loop is done - if the result is ERR_FAILED then close with error.
|
| - if (rv == net::ERR_FAILED)
|
| + if (rv == net::ERR_FAILED) {
|
| CloseWithError();
|
| + }
|
| }
|
|
|
| int CastSocket::DoWrite() {
|
| @@ -705,10 +708,11 @@ int CastSocket::DoWriteComplete(int result) {
|
| WriteRequest& request = write_queue_.front();
|
| scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer;
|
| io_buffer->DidConsume(result);
|
| - if (io_buffer->BytesRemaining() == 0) // Message fully sent
|
| + if (io_buffer->BytesRemaining() == 0) { // Message fully sent
|
| SetWriteState(WRITE_STATE_DO_CALLBACK);
|
| - else
|
| + } else {
|
| SetWriteState(WRITE_STATE_WRITE);
|
| + }
|
|
|
| return net::OK;
|
| }
|
| @@ -824,24 +828,14 @@ void CastSocket::DoReadLoop(int result) {
|
|
|
| int CastSocket::DoRead() {
|
| SetReadState(READ_STATE_READ_COMPLETE);
|
| - // Figure out whether to read header or body, and the remaining bytes.
|
| - uint32 num_bytes_to_read = 0;
|
| - if (header_read_buffer_->RemainingCapacity() > 0) {
|
| - current_read_buffer_ = header_read_buffer_;
|
| - num_bytes_to_read = header_read_buffer_->RemainingCapacity();
|
| - CHECK_LE(num_bytes_to_read, MessageHeader::header_size());
|
| - } else {
|
| - DCHECK_GT(current_message_size_, 0U);
|
| - num_bytes_to_read = current_message_size_ - body_read_buffer_->offset();
|
| - current_read_buffer_ = body_read_buffer_;
|
| - CHECK_LE(num_bytes_to_read, MessageHeader::max_message_size());
|
| - }
|
| - CHECK_GT(num_bytes_to_read, 0U);
|
| +
|
| + // Determine how many bytes need to be read.
|
| + size_t num_bytes_to_read = framer_->BytesRequested();
|
|
|
| // Read up to num_bytes_to_read into |current_read_buffer_|.
|
| int rv = socket_->Read(
|
| - current_read_buffer_.get(),
|
| - num_bytes_to_read,
|
| + read_buffer_.get(),
|
| + base::checked_cast<uint32>(num_bytes_to_read),
|
| base::Bind(&CastSocket::DoReadLoop, base::Unretained(this)));
|
| logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, rv);
|
|
|
| @@ -849,10 +843,8 @@ int CastSocket::DoRead() {
|
| }
|
|
|
| int CastSocket::DoReadComplete(int result) {
|
| - VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result
|
| - << " header offset = "
|
| - << header_read_buffer_->offset()
|
| - << " body offset = " << body_read_buffer_->offset();
|
| + VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result;
|
| +
|
| if (result <= 0) { // 0 means EOF: the peer closed the socket
|
| VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket";
|
| SetErrorState(CHANNEL_ERROR_SOCKET_ERROR);
|
| @@ -860,44 +852,25 @@ int CastSocket::DoReadComplete(int result) {
|
| return result == 0 ? net::ERR_FAILED : result;
|
| }
|
|
|
| - // Some data was read. Move the offset in the current buffer forward.
|
| - CHECK_LE(current_read_buffer_->offset() + result,
|
| - current_read_buffer_->capacity());
|
| - current_read_buffer_->set_offset(current_read_buffer_->offset() + result);
|
| -
|
| - if (current_read_buffer_.get() == header_read_buffer_.get() &&
|
| - current_read_buffer_->RemainingCapacity() == 0) {
|
| - // A full header is read, process the contents.
|
| - if (!ProcessHeader()) {
|
| - SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
|
| - SetReadState(READ_STATE_ERROR);
|
| - } else {
|
| - // Processed header, now read the body.
|
| - SetReadState(READ_STATE_READ);
|
| - }
|
| - } else if (current_read_buffer_.get() == body_read_buffer_.get() &&
|
| - static_cast<uint32>(current_read_buffer_->offset()) ==
|
| - current_message_size_) {
|
| - // Store a copy of current_message_size_ since it will be reset by
|
| - // ProcessBody().
|
| - uint32 message_size = current_message_size_;
|
| - // Full body is read, process the contents.
|
| - if (ProcessBody()) {
|
| - logger_->LogSocketEventForMessage(
|
| - channel_id_,
|
| - proto::MESSAGE_READ,
|
| - current_message_->namespace_(),
|
| - base::StringPrintf("Message size: %u", message_size));
|
| - SetReadState(READ_STATE_DO_CALLBACK);
|
| - } else {
|
| - SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
|
| - SetReadState(READ_STATE_ERROR);
|
| - }
|
| + size_t message_size;
|
| + DCHECK(current_message_.get() == NULL);
|
| + current_message_ = framer_->Ingest(result, &message_size, &error_state_);
|
| + if (current_message_.get()) {
|
| + DCHECK_EQ(error_state_, CHANNEL_ERROR_NONE);
|
| + DCHECK_GT(message_size, static_cast<size_t>(0));
|
| + logger_->LogSocketEventForMessage(
|
| + channel_id_,
|
| + proto::MESSAGE_READ,
|
| + current_message_->namespace_(),
|
| + base::StringPrintf("Message size: %zu", message_size));
|
| + SetReadState(READ_STATE_DO_CALLBACK);
|
| + } else if (error_state_ != CHANNEL_ERROR_NONE) {
|
| + DCHECK(current_message_.get() == NULL);
|
| + SetReadState(READ_STATE_ERROR);
|
| } else {
|
| - // Have not received full header or full body yet; keep reading.
|
| + DCHECK(current_message_.get() == NULL);
|
| SetReadState(READ_STATE_READ);
|
| }
|
| -
|
| return net::OK;
|
| }
|
|
|
| @@ -909,17 +882,19 @@ int CastSocket::DoReadCallback() {
|
| challenge_reply_.reset(new CastMessage(message));
|
| logger_->LogSocketEvent(channel_id_, proto::RECEIVED_CHALLENGE_REPLY);
|
| PostTaskToStartConnectLoop(net::OK);
|
| + current_message_.reset();
|
| return net::OK;
|
| } else {
|
| SetReadState(READ_STATE_ERROR);
|
| SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
|
| + current_message_.reset();
|
| return net::ERR_INVALID_RESPONSE;
|
| }
|
| }
|
|
|
| MessageInfo message_info;
|
| if (!CastMessageToMessageInfo(message, &message_info)) {
|
| - current_message_->Clear();
|
| + current_message_.reset();
|
| SetReadState(READ_STATE_ERROR);
|
| SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE);
|
| return net::ERR_INVALID_RESPONSE;
|
| @@ -930,7 +905,7 @@ int CastSocket::DoReadCallback() {
|
| message.namespace_(),
|
| std::string());
|
| delegate_->OnMessage(this, message_info);
|
| - current_message_->Clear();
|
| + current_message_.reset();
|
|
|
| return net::OK;
|
| }
|
| @@ -940,49 +915,7 @@ int CastSocket::DoReadError(int result) {
|
| return net::ERR_FAILED;
|
| }
|
|
|
| -bool CastSocket::ProcessHeader() {
|
| - CHECK_EQ(static_cast<uint32>(header_read_buffer_->offset()),
|
| - MessageHeader::header_size());
|
| - MessageHeader header;
|
| - MessageHeader::ReadFromIOBuffer(header_read_buffer_.get(), &header);
|
| - if (header.message_size > MessageHeader::max_message_size())
|
| - return false;
|
|
|
| - VLOG_WITH_CONNECTION(2) << "Parsed header { message_size: "
|
| - << header.message_size << " }";
|
| - current_message_size_ = header.message_size;
|
| - return true;
|
| -}
|
| -
|
| -bool CastSocket::ProcessBody() {
|
| - CHECK_EQ(static_cast<uint32>(body_read_buffer_->offset()),
|
| - current_message_size_);
|
| - if (!current_message_->ParseFromArray(
|
| - body_read_buffer_->StartOfBuffer(), current_message_size_)) {
|
| - return false;
|
| - }
|
| - current_message_size_ = 0;
|
| - header_read_buffer_->set_offset(0);
|
| - body_read_buffer_->set_offset(0);
|
| - current_read_buffer_ = header_read_buffer_;
|
| - return true;
|
| -}
|
| -
|
| -// static
|
| -bool CastSocket::Serialize(const CastMessage& message_proto,
|
| - std::string* message_data) {
|
| - DCHECK(message_data);
|
| - message_proto.SerializeToString(message_data);
|
| - size_t message_size = message_data->size();
|
| - if (message_size > MessageHeader::max_message_size()) {
|
| - message_data->clear();
|
| - return false;
|
| - }
|
| - CastSocket::MessageHeader header;
|
| - header.SetMessageSize(message_size);
|
| - header.PrependToString(message_data);
|
| - return true;
|
| -}
|
|
|
| void CastSocket::CloseWithError() {
|
| DCHECK(CalledOnValidThread());
|
| @@ -1043,50 +976,15 @@ void CastSocket::SetWriteState(WriteState write_state) {
|
| }
|
| }
|
|
|
| -CastSocket::MessageHeader::MessageHeader() : message_size(0) { }
|
| -
|
| -void CastSocket::MessageHeader::SetMessageSize(size_t size) {
|
| - DCHECK_LT(size, static_cast<size_t>(kuint32max));
|
| - DCHECK_GT(size, 0U);
|
| - message_size = size;
|
| -}
|
| -
|
| -// TODO(mfoltz): Investigate replacing header serialization with base::Pickle,
|
| -// if bit-for-bit compatible.
|
| -void CastSocket::MessageHeader::PrependToString(std::string* str) {
|
| - MessageHeader output = *this;
|
| - output.message_size = base::HostToNet32(message_size);
|
| - size_t header_size = base::checked_cast<size_t, uint32>(
|
| - MessageHeader::header_size());
|
| - scoped_ptr<char, base::FreeDeleter> char_array(
|
| - static_cast<char*>(malloc(header_size)));
|
| - memcpy(char_array.get(), &output, header_size);
|
| - str->insert(0, char_array.get(), header_size);
|
| -}
|
| -
|
| -// TODO(mfoltz): Investigate replacing header deserialization with base::Pickle,
|
| -// if bit-for-bit compatible.
|
| -void CastSocket::MessageHeader::ReadFromIOBuffer(
|
| - net::GrowableIOBuffer* buffer, MessageHeader* header) {
|
| - uint32 message_size;
|
| - size_t header_size = base::checked_cast<size_t, uint32>(
|
| - MessageHeader::header_size());
|
| - memcpy(&message_size, buffer->StartOfBuffer(), header_size);
|
| - header->message_size = base::NetToHost32(message_size);
|
| -}
|
| -
|
| -std::string CastSocket::MessageHeader::ToString() {
|
| - return "{message_size: " + base::UintToString(message_size) + "}";
|
| -}
|
| -
|
| CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback)
|
| : callback(callback) { }
|
|
|
| bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
|
| DCHECK(!io_buffer.get());
|
| std::string message_data;
|
| - if (!Serialize(message_proto, &message_data))
|
| + if (!MessageFramer::Serialize(message_proto, &message_data)) {
|
| return false;
|
| + }
|
| message_namespace = message_proto.namespace_();
|
| io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data),
|
| message_data.size());
|
| @@ -1094,7 +992,6 @@ bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) {
|
| }
|
|
|
| CastSocket::WriteRequest::~WriteRequest() { }
|
| -
|
| } // namespace cast_channel
|
| } // namespace core_api
|
| } // namespace extensions
|
|
|