Chromium Code Reviews| Index: extensions/browser/api/cast_channel/cast_socket.cc |
| diff --git a/extensions/browser/api/cast_channel/cast_socket.cc b/extensions/browser/api/cast_channel/cast_socket.cc |
| index dac6e0d395b84e113b6811912310f1f3c6219ba0..150af5b9515529dafd70b699e12a44014e89f283 100644 |
| --- a/extensions/browser/api/cast_channel/cast_socket.cc |
| +++ b/extensions/browser/api/cast_channel/cast_socket.cc |
| @@ -47,7 +47,10 @@ namespace { |
| // after 9 failed probes. So the total idle time before close is 10 * |
| // kTcpKeepAliveDelaySecs. |
| const int kTcpKeepAliveDelaySecs = 10; |
| - |
| +// Size of a CastSocket header payload. |
| +const size_t kHeaderSizeBytes = sizeof(int32); |
| +// Maximum byte count for a CastSocket message. |
| +const size_t kMaxMessageSizeBytes = 65536; |
| } // namespace |
| namespace extensions { |
| @@ -189,13 +192,12 @@ CastSocket::CastSocket(const std::string& owner_extension_id, |
| ip_endpoint_(ip_endpoint), |
| channel_auth_(channel_auth), |
| delegate_(delegate), |
| - current_message_size_(0), |
| - current_message_(new CastMessage()), |
| net_log_(net_log), |
| logger_(logger), |
| connect_timeout_(timeout), |
| connect_timeout_timer_(new base::OneShotTimer<CastSocket>), |
| is_canceled_(false), |
| + current_message_(new CastMessage), |
| connect_state_(CONN_STATE_NONE), |
| write_state_(WRITE_STATE_NONE), |
| read_state_(READ_STATE_NONE), |
| @@ -207,12 +209,10 @@ CastSocket::CastSocket(const std::string& owner_extension_id, |
| net_log_source_.type = net::NetLog::SOURCE_SOCKET; |
| net_log_source_.id = net_log_->NextID(); |
| - // Reuse these buffers for each message. |
| - header_read_buffer_ = new net::GrowableIOBuffer(); |
| - header_read_buffer_->SetCapacity(MessageHeader::header_size()); |
| - body_read_buffer_ = new net::GrowableIOBuffer(); |
| - body_read_buffer_->SetCapacity(MessageHeader::max_message_size()); |
| - current_read_buffer_ = header_read_buffer_; |
| + // Buffer is reused across messages. |
| + read_buffer_ = new net::GrowableIOBuffer(); |
| + read_buffer_->SetCapacity(kMaxMessageSizeBytes); |
| + framer_.reset(new PacketFramer(read_buffer_)); |
| } |
| CastSocket::~CastSocket() { |
| @@ -272,16 +272,18 @@ bool CastSocket::ExtractPeerCert(std::string* cert) { |
| DCHECK(cert); |
| DCHECK(peer_cert_.empty()); |
| net::SSLInfo ssl_info; |
| - if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) |
| + if (!socket_->GetSSLInfo(&ssl_info) || !ssl_info.cert.get()) { |
| return false; |
| + } |
| logger_->LogSocketEvent(channel_id_, proto::SSL_INFO_OBTAINED); |
| bool result = net::X509Certificate::GetDEREncoded( |
| ssl_info.cert->os_cert_handle(), cert); |
| - if (result) |
| + if (result) { |
| VLOG_WITH_CONNECTION(1) << "Successfully extracted peer certificate: " |
| << *cert; |
| + } |
| logger_->LogSocketEventWithRv( |
| channel_id_, proto::DER_ENCODED_CERT_OBTAIN, result ? 1 : 0); |
| @@ -488,8 +490,9 @@ void CastSocket::DoAuthChallengeSendWriteComplete(int result) { |
| int CastSocket::DoAuthChallengeSendComplete(int result) { |
| VLOG_WITH_CONNECTION(1) << "DoAuthChallengeSendComplete: " << result; |
| - if (result < 0) |
| + if (result < 0) { |
| return result; |
| + } |
| SetConnectState(CONN_STATE_AUTH_CHALLENGE_REPLY_COMPLETE); |
| // Post a task to start read loop so that DoReadLoop is not nested inside |
| @@ -502,10 +505,12 @@ int CastSocket::DoAuthChallengeSendComplete(int result) { |
| int CastSocket::DoAuthChallengeReplyComplete(int result) { |
| VLOG_WITH_CONNECTION(1) << "DoAuthChallengeReplyComplete: " << result; |
| - if (result < 0) |
| + if (result < 0) { |
| return result; |
| - if (!VerifyChallengeReply()) |
| + } |
| + if (!VerifyChallengeReply()) { |
| return net::ERR_FAILED; |
| + } |
| VLOG_WITH_CONNECTION(1) << "Auth challenge verification succeeded"; |
| return net::OK; |
| } |
| @@ -666,12 +671,14 @@ void CastSocket::DoWriteLoop(int result) { |
| // If write loop is done because the queue is empty then set write |
| // state to NONE |
| - if (write_queue_.empty()) |
| + if (write_queue_.empty()) { |
| SetWriteState(WRITE_STATE_NONE); |
| + } |
| // Write loop is done - if the result is ERR_FAILED then close with error. |
| - if (rv == net::ERR_FAILED) |
| + if (rv == net::ERR_FAILED) { |
| CloseWithError(); |
| + } |
| } |
| int CastSocket::DoWrite() { |
| @@ -705,10 +712,11 @@ int CastSocket::DoWriteComplete(int result) { |
| WriteRequest& request = write_queue_.front(); |
| scoped_refptr<net::DrainableIOBuffer> io_buffer = request.io_buffer; |
| io_buffer->DidConsume(result); |
| - if (io_buffer->BytesRemaining() == 0) // Message fully sent |
| + if (io_buffer->BytesRemaining() == 0) { // Message fully sent |
| SetWriteState(WRITE_STATE_DO_CALLBACK); |
| - else |
| + } else { |
| SetWriteState(WRITE_STATE_WRITE); |
| + } |
| return net::OK; |
| } |
| @@ -824,23 +832,13 @@ void CastSocket::DoReadLoop(int result) { |
| int CastSocket::DoRead() { |
| SetReadState(READ_STATE_READ_COMPLETE); |
| - // Figure out whether to read header or body, and the remaining bytes. |
| - uint32 num_bytes_to_read = 0; |
| - if (header_read_buffer_->RemainingCapacity() > 0) { |
| - current_read_buffer_ = header_read_buffer_; |
| - num_bytes_to_read = header_read_buffer_->RemainingCapacity(); |
| - CHECK_LE(num_bytes_to_read, MessageHeader::header_size()); |
| - } else { |
| - DCHECK_GT(current_message_size_, 0U); |
| - num_bytes_to_read = current_message_size_ - body_read_buffer_->offset(); |
| - current_read_buffer_ = body_read_buffer_; |
| - CHECK_LE(num_bytes_to_read, MessageHeader::max_message_size()); |
| - } |
| - CHECK_GT(num_bytes_to_read, 0U); |
| + |
| + // Determine how many bytes need to be read. |
| + uint32 num_bytes_to_read = framer_->BytesRequested(); |
|
mark a. foltz
2014/08/26 20:37:02
Slightly prefer that this be a size_t and then is
Kevin M
2014/08/27 01:14:03
Done.
|
| // Read up to num_bytes_to_read into |current_read_buffer_|. |
| int rv = socket_->Read( |
| - current_read_buffer_.get(), |
| + read_buffer_.get(), |
| num_bytes_to_read, |
| base::Bind(&CastSocket::DoReadLoop, base::Unretained(this))); |
| logger_->LogSocketEventWithRv(channel_id_, proto::SOCKET_READ, rv); |
| @@ -849,10 +847,8 @@ int CastSocket::DoRead() { |
| } |
| int CastSocket::DoReadComplete(int result) { |
| - VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result |
| - << " header offset = " |
| - << header_read_buffer_->offset() |
| - << " body offset = " << body_read_buffer_->offset(); |
| + VLOG_WITH_CONNECTION(2) << "DoReadComplete result = " << result; |
| + |
| if (result <= 0) { // 0 means EOF: the peer closed the socket |
| VLOG_WITH_CONNECTION(1) << "Read error, peer closed the socket"; |
| SetErrorState(CHANNEL_ERROR_SOCKET_ERROR); |
| @@ -860,44 +856,20 @@ int CastSocket::DoReadComplete(int result) { |
| return result == 0 ? net::ERR_FAILED : result; |
| } |
| - // Some data was read. Move the offset in the current buffer forward. |
| - CHECK_LE(current_read_buffer_->offset() + result, |
| - current_read_buffer_->capacity()); |
| - current_read_buffer_->set_offset(current_read_buffer_->offset() + result); |
| - |
| - if (current_read_buffer_.get() == header_read_buffer_.get() && |
| - current_read_buffer_->RemainingCapacity() == 0) { |
| - // A full header is read, process the contents. |
| - if (!ProcessHeader()) { |
| - SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); |
| - SetReadState(READ_STATE_ERROR); |
| - } else { |
| - // Processed header, now read the body. |
| - SetReadState(READ_STATE_READ); |
| - } |
| - } else if (current_read_buffer_.get() == body_read_buffer_.get() && |
| - static_cast<uint32>(current_read_buffer_->offset()) == |
| - current_message_size_) { |
| - // Store a copy of current_message_size_ since it will be reset by |
| - // ProcessBody(). |
| - uint32 message_size = current_message_size_; |
| - // Full body is read, process the contents. |
| - if (ProcessBody()) { |
| - logger_->LogSocketEventForMessage( |
| - channel_id_, |
| - proto::MESSAGE_READ, |
| - current_message_->namespace_(), |
| - base::StringPrintf("Message size: %u", message_size)); |
| - SetReadState(READ_STATE_DO_CALLBACK); |
| - } else { |
| - SetErrorState(CHANNEL_ERROR_INVALID_MESSAGE); |
| - SetReadState(READ_STATE_ERROR); |
| - } |
| + size_t message_size; |
| + if (framer_->Ingest( |
| + result, current_message_.get(), &message_size, &error_state_)) { |
| + logger_->LogSocketEventForMessage( |
| + channel_id_, |
| + proto::MESSAGE_READ, |
| + current_message_->namespace_(), |
| + base::StringPrintf("Message size: %zu", message_size)); |
| + SetReadState(READ_STATE_DO_CALLBACK); |
| + } else if (error_state_ != CHANNEL_ERROR_NONE) { |
| + SetReadState(READ_STATE_ERROR); |
| } else { |
| - // Have not received full header or full body yet; keep reading. |
| SetReadState(READ_STATE_READ); |
| } |
| - |
| return net::OK; |
| } |
| @@ -940,49 +912,7 @@ int CastSocket::DoReadError(int result) { |
| return net::ERR_FAILED; |
| } |
| -bool CastSocket::ProcessHeader() { |
| - CHECK_EQ(static_cast<uint32>(header_read_buffer_->offset()), |
| - MessageHeader::header_size()); |
| - MessageHeader header; |
| - MessageHeader::ReadFromIOBuffer(header_read_buffer_.get(), &header); |
| - if (header.message_size > MessageHeader::max_message_size()) |
| - return false; |
| - VLOG_WITH_CONNECTION(2) << "Parsed header { message_size: " |
| - << header.message_size << " }"; |
| - current_message_size_ = header.message_size; |
| - return true; |
| -} |
| - |
| -bool CastSocket::ProcessBody() { |
| - CHECK_EQ(static_cast<uint32>(body_read_buffer_->offset()), |
| - current_message_size_); |
| - if (!current_message_->ParseFromArray( |
| - body_read_buffer_->StartOfBuffer(), current_message_size_)) { |
| - return false; |
| - } |
| - current_message_size_ = 0; |
| - header_read_buffer_->set_offset(0); |
| - body_read_buffer_->set_offset(0); |
| - current_read_buffer_ = header_read_buffer_; |
| - return true; |
| -} |
| - |
| -// static |
| -bool CastSocket::Serialize(const CastMessage& message_proto, |
| - std::string* message_data) { |
| - DCHECK(message_data); |
| - message_proto.SerializeToString(message_data); |
| - size_t message_size = message_data->size(); |
| - if (message_size > MessageHeader::max_message_size()) { |
| - message_data->clear(); |
| - return false; |
| - } |
| - CastSocket::MessageHeader header; |
| - header.SetMessageSize(message_size); |
| - header.PrependToString(message_data); |
| - return true; |
| -} |
| void CastSocket::CloseWithError() { |
| DCHECK(CalledOnValidThread()); |
| @@ -1043,9 +973,18 @@ void CastSocket::SetWriteState(WriteState write_state) { |
| } |
| } |
| -CastSocket::MessageHeader::MessageHeader() : message_size(0) { } |
| +PacketFramer::PacketFramer(scoped_refptr<net::GrowableIOBuffer> buffer) |
| + : buffer_(buffer) { |
| + Reset(); |
| +} |
| -void CastSocket::MessageHeader::SetMessageSize(size_t size) { |
| +PacketFramer::~PacketFramer() { |
| +} |
| + |
| +PacketFramer::MessageHeader::MessageHeader() : message_size(0) { |
| +} |
| + |
| +void PacketFramer::MessageHeader::SetMessageSize(size_t size) { |
| DCHECK_LT(size, static_cast<size_t>(kuint32max)); |
| DCHECK_GT(size, 0U); |
| message_size = size; |
| @@ -1053,7 +992,7 @@ void CastSocket::MessageHeader::SetMessageSize(size_t size) { |
| // TODO(mfoltz): Investigate replacing header serialization with base::Pickle, |
| // if bit-for-bit compatible. |
| -void CastSocket::MessageHeader::PrependToString(std::string* str) { |
| +void PacketFramer::MessageHeader::PrependToString(std::string* str) { |
| MessageHeader output = *this; |
| output.message_size = base::HostToNet32(message_size); |
| size_t header_size = base::checked_cast<size_t, uint32>( |
| @@ -1066,16 +1005,24 @@ void CastSocket::MessageHeader::PrependToString(std::string* str) { |
| // TODO(mfoltz): Investigate replacing header deserialization with base::Pickle, |
| // if bit-for-bit compatible. |
| -void CastSocket::MessageHeader::ReadFromIOBuffer( |
| - net::GrowableIOBuffer* buffer, MessageHeader* header) { |
| +void PacketFramer::MessageHeader::Deserialize(char* data, |
| + MessageHeader* header) { |
| uint32 message_size; |
| - size_t header_size = base::checked_cast<size_t, uint32>( |
| - MessageHeader::header_size()); |
| - memcpy(&message_size, buffer->StartOfBuffer(), header_size); |
| + memcpy(&message_size, data, kHeaderSizeBytes); |
| header->message_size = base::NetToHost32(message_size); |
| } |
| -std::string CastSocket::MessageHeader::ToString() { |
| +// static |
| +uint32 PacketFramer::MessageHeader::header_size() { |
| + return kHeaderSizeBytes; |
| +} |
| + |
| +// static |
| +uint32 PacketFramer::MessageHeader::max_message_size() { |
| + return kMaxMessageSizeBytes; |
| +} |
| + |
| +std::string PacketFramer::MessageHeader::ToString() { |
| return "{message_size: " + base::UintToString(message_size) + "}"; |
| } |
| @@ -1085,8 +1032,9 @@ CastSocket::WriteRequest::WriteRequest(const net::CompletionCallback& callback) |
| bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) { |
| DCHECK(!io_buffer.get()); |
| std::string message_data; |
| - if (!Serialize(message_proto, &message_data)) |
| + if (!PacketFramer::Serialize(message_proto, &message_data)) { |
| return false; |
| + } |
| message_namespace = message_proto.namespace_(); |
| io_buffer = new net::DrainableIOBuffer(new net::StringIOBuffer(message_data), |
| message_data.size()); |
| @@ -1095,6 +1043,87 @@ bool CastSocket::WriteRequest::SetContent(const CastMessage& message_proto) { |
| CastSocket::WriteRequest::~WriteRequest() { } |
| +// static |
| +bool PacketFramer::Serialize(const CastMessage& message_proto, |
| + std::string* message_data) { |
| + DCHECK(message_data); |
| + message_proto.SerializeToString(message_data); |
| + size_t message_size = message_data->size(); |
| + if (message_size > kMaxMessageSizeBytes) { |
| + message_data->clear(); |
| + return false; |
| + } |
| + MessageHeader header; |
| + header.SetMessageSize(message_size); |
| + header.PrependToString(message_data); |
| + return true; |
| +} |
| + |
| +size_t PacketFramer::BytesRequested() { |
| + if (current_element_ == HEADER) { |
| + size_t bytes_left = kHeaderSizeBytes - packet_bytes_read_; |
| + VLOG(2) << "Bytes needed for header: " << bytes_left; |
| + return bytes_left; |
| + } else if (current_element_ == BODY) { |
| + size_t bytes_left = (message_size_ + kHeaderSizeBytes) - packet_bytes_read_; |
| + VLOG(2) << "Bytes needed for body: " << bytes_left; |
| + return bytes_left; |
| + } else { |
| + NOTREACHED() << "Unhandled packet element type."; |
| + return 0; |
| + } |
| +} |
| + |
| +bool PacketFramer::Ingest(uint32 num_bytes, |
| + CastMessage* message, |
|
mark a. foltz
2014/08/26 20:37:02
I feel like the message being read should be owned
Kevin M
2014/08/27 01:14:03
Done.
|
| + size_t* message_length, |
| + ChannelError* error) { |
| + DCHECK_EQ(base::checked_cast<int32>(packet_bytes_read_), buffer_->offset()); |
| + DCHECK(message); |
| + DCHECK(error); |
| + CHECK_LE(num_bytes, BytesRequested()); |
| + |
| + bool was_message_parsed = false; |
| + packet_bytes_read_ += num_bytes; |
| + *error = CHANNEL_ERROR_NONE; |
| + if (current_element_ == HEADER) { |
| + if (BytesRequested() == 0) { |
| + MessageHeader header; |
| + MessageHeader::Deserialize(buffer_.get()->StartOfBuffer(), &header); |
| + if (header.message_size > MessageHeader::max_message_size()) { |
| + VLOG(1) << "Error parsing header (message size too large)."; |
| + *error = CHANNEL_ERROR_INVALID_MESSAGE; |
| + return false; |
| + } |
| + current_element_ = BODY; |
| + message_size_ = header.message_size; |
| + } |
| + } else if (current_element_ == BODY) { |
| + if (BytesRequested() == 0) { |
| + CastMessage parsed_message; |
| + if (!parsed_message.ParseFromArray( |
| + buffer_->StartOfBuffer() + kHeaderSizeBytes, message_size_)) { |
| + VLOG(1) << "Error parsing packet body."; |
| + *error = CHANNEL_ERROR_INVALID_MESSAGE; |
| + return false; |
| + } |
| + parsed_message.Swap(message); |
| + *message_length = message_size_; |
| + was_message_parsed = true; |
| + Reset(); |
| + } |
| + } |
| + |
| + buffer_->set_offset(packet_bytes_read_); |
| + return was_message_parsed; |
| +} |
| + |
| +void PacketFramer::Reset() { |
| + current_element_ = HEADER; |
| + packet_bytes_read_ = 0; |
| + message_size_ = 0; |
| + buffer_->set_offset(0); |
| +} |
| } // namespace cast_channel |
| } // namespace core_api |
| } // namespace extensions |