Chromium Code Reviews| Index: remoting/protocol/channel_multiplexer.cc |
| diff --git a/remoting/protocol/channel_multiplexer.cc b/remoting/protocol/channel_multiplexer.cc |
| new file mode 100644 |
| index 0000000000000000000000000000000000000000..ad6c6860f0af5ae95e8c310f64ba450114ab7a43 |
| --- /dev/null |
| +++ b/remoting/protocol/channel_multiplexer.cc |
| @@ -0,0 +1,485 @@ |
| +// Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| +// Use of this source code is governed by a BSD-style license that can be |
| +// found in the LICENSE file. |
| + |
| +#include "remoting/protocol/channel_multiplexer.h" |
| + |
| +#include <string.h> |
| + |
| +#include "base/bind.h" |
| +#include "base/callback.h" |
| +#include "base/location.h" |
| +#include "base/stl_util.h" |
| +#include "net/base/net_errors.h" |
| +#include "net/socket/stream_socket.h" |
| +#include "remoting/protocol/util.h" |
| + |
| +namespace remoting { |
| +namespace protocol { |
| + |
| +namespace { |
| +const int kChannelIdUnknown = -1; |
| +const int kMaxPacketSize = 1024; |
| + |
| +class PendingPacket { |
| + public: |
| + PendingPacket(scoped_ptr<MuxPacket> packet, const base::Closure& done_task) |
| + : packet(packet.Pass()), |
| + done_task(done_task), |
| + pos(0U) { |
| + } |
| + ~PendingPacket() { |
| + done_task.Run(); |
| + } |
| + |
| + bool is_empty() { return pos >= packet->data().size(); } |
| + |
| + int Read(char* buffer, size_t size) { |
| + size = std::min(size, packet->data().size() - pos); |
| + memcpy(buffer, packet->data().data() + pos, size); |
| + pos += size; |
| + return size; |
| + } |
| + |
| + private: |
| + scoped_ptr<MuxPacket> packet; |
| + base::Closure done_task; |
| + size_t pos; |
| + |
| + DISALLOW_COPY_AND_ASSIGN(PendingPacket); |
| +}; |
| + |
| +} // namespace |
| + |
| +const char ChannelMultiplexer::kMuxChannelName[] = "mux"; |
| + |
| +struct ChannelMultiplexer::PendingChannel { |
| + PendingChannel(const std::string& name, |
| + const StreamChannelCallback& callback) |
| + : name(name), callback(callback) { |
| + } |
| + std::string name; |
| + StreamChannelCallback callback; |
| +}; |
| + |
| +class ChannelMultiplexer::MuxChannel { |
| + public: |
| + MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, int id); |
| + ~MuxChannel(); |
| + |
| + const std::string& name() { return name_; } |
| + int id() { return id_; } |
|
Wez
2012/08/06 23:14:07
nit: local_id?
Sergey Ulanov
2012/08/07 20:12:22
renamed to send id
|
| + int remote_id() { return remote_id_; } |
|
Wez
2012/08/06 23:14:07
nit: Or received_id() and send_id()?
Sergey Ulanov
2012/08/07 20:12:22
Done.
|
| + void set_remote_id(int id) { remote_id_ = id; } |
| + |
| + // Called by ChannelMultiplexer. |
| + scoped_ptr<net::StreamSocket> CreateSocket(); |
| + void OnIncomingPacket(scoped_ptr<MuxPacket> packet, |
| + const base::Closure& done_task); |
| + void OnWriteFailed(); |
| + |
| + // Called by MuxSocket. |
| + void OnSocketDestroyed(); |
| + void DoWrite(scoped_ptr<MuxPacket> packet, const base::Closure& done_task); |
| + int DoRead(net::IOBuffer* buffer, int buffer_len); |
| + |
| + private: |
| + ChannelMultiplexer* multiplexer_; |
| + std::string name_; |
| + int id_; |
| + bool id_sent_; |
| + int remote_id_; |
| + MuxSocket* socket_; |
| + std::list<PendingPacket*> pending_packets_; |
| + |
| + DISALLOW_COPY_AND_ASSIGN(MuxChannel); |
| +}; |
| + |
| +class ChannelMultiplexer::MuxSocket : public net::StreamSocket, |
| + public base::NonThreadSafe, |
| + public base::SupportsWeakPtr<MuxSocket> { |
| + public: |
| + MuxSocket(MuxChannel* channel); |
| + ~MuxSocket(); |
| + |
| + void OnWriteComplete(); |
| + void OnWriteFailed(); |
| + void OnPacketReceived(); |
| + |
| + // net::StreamSocket interface. |
| + virtual int Read(net::IOBuffer* buffer, int buffer_len, |
| + const net::CompletionCallback& callback) OVERRIDE; |
| + virtual int Write(net::IOBuffer* buffer, int buffer_len, |
| + const net::CompletionCallback& callback) OVERRIDE; |
| + |
| + virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return false; |
| + } |
| + virtual bool SetSendBufferSize(int32 size) OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return false; |
| + } |
| + |
| + virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return net::ERR_FAILED; |
| + } |
| + virtual void Disconnect() OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + } |
| + virtual bool IsConnected() const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return true; |
| + } |
| + virtual bool IsConnectedAndIdle() const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return false; |
| + } |
| + virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return net::ERR_FAILED; |
| + } |
| + virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return net::ERR_FAILED; |
| + } |
| + virtual const net::BoundNetLog& NetLog() const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return net_log_; |
| + } |
| + virtual void SetSubresourceSpeculation() OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + } |
| + virtual void SetOmniboxSpeculation() OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + } |
| + virtual bool WasEverUsed() const OVERRIDE { |
| + return true; |
| + } |
| + virtual bool UsingTCPFastOpen() const OVERRIDE { |
| + return false; |
| + } |
| + virtual int64 NumBytesRead() const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return 0; |
| + } |
| + virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return base::TimeDelta(); |
| + } |
| + virtual bool WasNpnNegotiated() const OVERRIDE { |
| + return false; |
| + } |
| + virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { |
| + return net::kProtoUnknown; |
| + } |
| + virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { |
| + NOTIMPLEMENTED(); |
| + return false; |
| + } |
| + |
| + private: |
| + MuxChannel* channel_; |
| + |
| + net::CompletionCallback read_callback_; |
| + scoped_refptr<net::IOBuffer> read_buffer_; |
| + int read_buffer_size_; |
| + |
| + bool write_pending_; |
| + int write_result_; |
| + net::CompletionCallback write_callback_; |
| + bool write_failed_; |
| + |
| + net::BoundNetLog net_log_; |
| + |
| + DISALLOW_COPY_AND_ASSIGN(MuxSocket); |
| +}; |
| + |
| + |
| +ChannelMultiplexer::MuxChannel::MuxChannel( |
| + ChannelMultiplexer* multiplexer, |
| + const std::string& name, |
| + int id) |
| + : multiplexer_(multiplexer), |
| + name_(name), |
| + id_(id), |
| + id_sent_(false), |
| + remote_id_(kChannelIdUnknown), |
| + socket_(NULL) { |
| +} |
| + |
| +ChannelMultiplexer::MuxChannel::~MuxChannel() { |
| + // Socket must be destroyed before the channel. |
| + DCHECK(!socket_); |
| + STLDeleteElements(&pending_packets_); |
| +} |
| + |
| +scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { |
| + DCHECK(!socket_); // Can't create more than one socket per channel. |
| + scoped_ptr<MuxSocket> result(new MuxSocket(this)); |
| + socket_ = result.get(); |
| + return result.PassAs<net::StreamSocket>(); |
| +} |
| + |
| +void ChannelMultiplexer::MuxChannel::OnIncomingPacket( |
| + scoped_ptr<MuxPacket> packet, |
| + const base::Closure& done_task) { |
| + DCHECK_EQ(packet->channel_id(), remote_id_); |
| + if (packet->data().size() > 0) { |
| + pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); |
| + if (socket_) { |
| + // Notify the socket that we have more data. |
| + socket_->OnPacketReceived(); |
| + } |
| + } |
| +} |
| + |
| +void ChannelMultiplexer::MuxChannel::OnWriteFailed() { |
| + if (socket_) |
| + socket_->OnWriteFailed(); |
| +} |
| + |
| +void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { |
| + DCHECK(socket_); |
| + socket_ = NULL; |
|
Wez
2012/08/06 23:14:07
nit: Remove the MuxChannel from the list of active
Sergey Ulanov
2012/08/07 20:12:22
We need to keep this object in order to preserve n
|
| +} |
| + |
| +void ChannelMultiplexer::MuxChannel::DoWrite( |
| + scoped_ptr<MuxPacket> packet, |
| + const base::Closure& done_task) { |
| + packet->set_channel_id(id_); |
| + if (!id_sent_) { |
| + packet->set_channel_name(name_); |
| + id_sent_ = true; |
| + } |
| + multiplexer_->DoWrite(packet.Pass(), done_task); |
| +} |
| + |
| +int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, |
| + int buffer_len) { |
| + int pos = 0; |
| + while (buffer_len > 0 && !pending_packets_.empty()) { |
| + int result = pending_packets_.front()->Read( |
| + buffer->data() + pos, buffer_len); |
| + DCHECK_LE(result, buffer_len); |
| + pos += result; |
| + buffer_len -= pos; |
| + while (!pending_packets_.empty() && |
| + pending_packets_.front()->is_empty()) { |
|
Wez
2012/08/06 23:14:07
Do you need a while loop here? Under what circumst
Sergey Ulanov
2012/08/07 20:12:22
Done.
|
| + delete pending_packets_.front(); |
| + pending_packets_.erase(pending_packets_.begin()); |
| + } |
| + } |
| + return pos; |
| +} |
| + |
| +ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) |
| + : channel_(channel), |
| + read_buffer_size_(0), |
| + write_pending_(false), |
| + write_result_(0), |
| + write_failed_(false) { |
| +} |
| + |
| +ChannelMultiplexer::MuxSocket::~MuxSocket() { |
|
Wez
2012/08/06 23:14:07
Thread check before OnSocketDestroyed()?
Sergey Ulanov
2012/08/07 20:12:22
The class inherits from NonThreadSafe and ~NonThre
Wez
2012/08/07 21:44:44
Not until after OnSocketDestroyed() has already be
|
| + channel_->OnSocketDestroyed(); |
| +} |
| + |
| +int ChannelMultiplexer::MuxSocket::Read( |
| + net::IOBuffer* buffer, int buffer_len, |
| + const net::CompletionCallback& callback) { |
| + DCHECK(CalledOnValidThread()); |
| + DCHECK(read_callback_.is_null()); |
| + |
| + int result = channel_->DoRead(buffer, buffer_len); |
| + if (result == 0) { |
| + read_buffer_ = buffer; |
| + read_buffer_size_ = buffer_len; |
| + read_callback_ = callback; |
| + return net::ERR_IO_PENDING; |
| + } |
| + return result; |
| +} |
| + |
| +int ChannelMultiplexer::MuxSocket::Write( |
| + net::IOBuffer* buffer, int buffer_len, |
| + const net::CompletionCallback& callback) { |
| + DCHECK(CalledOnValidThread()); |
| + |
| + if (write_failed_) |
| + return net::ERR_FAILED; |
| + |
| + scoped_ptr<MuxPacket> packet(new MuxPacket()); |
| + size_t size = std::min(kMaxPacketSize, buffer_len); |
| + packet->mutable_data()->assign(buffer->data(), size); |
| + |
| + write_pending_ = true; |
| + channel_->DoWrite(packet.Pass(), base::Bind( |
| + &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); |
|
Wez
2012/08/06 23:14:07
nit: It's a little confusing that we set write_pen
Sergey Ulanov
2012/08/07 20:12:22
The semantics used by BufferedSocketWriter is easi
|
| + |
| + if (write_pending_) { |
| + DCHECK(write_callback_.is_null()); |
| + write_callback_ = callback; |
| + write_result_ = size; |
| + return net::ERR_IO_PENDING; |
| + } |
| + |
| + return write_failed_ ? net::ERR_FAILED : size; |
| +} |
| + |
| +void ChannelMultiplexer::MuxSocket::OnWriteComplete() { |
| + write_pending_ = false; |
| + if (!write_callback_.is_null()) { |
| + net::CompletionCallback cb; |
| + std::swap(cb, write_callback_); |
| + cb.Run(write_result_); |
| + } |
| +} |
| + |
| +void ChannelMultiplexer::MuxSocket::OnWriteFailed() { |
| + if (!write_callback_.is_null()) { |
| + net::CompletionCallback cb; |
| + std::swap(cb, write_callback_); |
| + cb.Run(net::ERR_FAILED); |
|
Wez
2012/08/06 23:14:07
Don't you need to do this after setting write_fail
Sergey Ulanov
2012/08/07 20:12:22
Good point. Fixed.
|
| + } |
| + write_failed_ = true; |
|
Wez
2012/08/06 23:14:07
nit: Do you actually need to keep a write_failed f
Sergey Ulanov
2012/08/07 20:12:22
Done.
|
| +} |
| + |
| +void ChannelMultiplexer::MuxSocket::OnPacketReceived() { |
| + if (!read_callback_.is_null()) { |
| + int result = channel_->DoRead(read_buffer_, read_buffer_size_); |
| + read_buffer_ = NULL; |
| + DCHECK_GT(result, 0); |
| + net::CompletionCallback cb; |
| + std::swap(cb, read_callback_); |
| + cb.Run(result); |
| + } |
| +} |
| + |
| +ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory) |
| + : base_channel_failed_(false), |
| + next_channel_id_(0) { |
| + factory->CreateStreamChannel( |
| + kMuxChannelName, base::Bind(&ChannelMultiplexer::OnBaseChannelReady, |
| + base::Unretained(this))); |
| +} |
| + |
| +ChannelMultiplexer::~ChannelMultiplexer() { |
| + STLDeleteValues(&channels_); |
| +} |
| + |
| +void ChannelMultiplexer::CreateStreamChannel( |
| + const std::string& name, |
| + const StreamChannelCallback& callback) { |
| + if (base_channel_.get()) { |
| + // Already have |base_channel_|. Create new multiplexed channel |
| + // synchronously. |
| + callback.Run(GetOrCreateChannel(name)->CreateSocket()); |
| + } else if (base_channel_failed_) { |
| + // Fail synchronously if we failed to create |base_channel_|. |
| + callback.Run(scoped_ptr<net::StreamSocket>()); |
| + } else { |
| + // Still waiting for the |base_channel_|. |
| + pending_channels_.push_back(PendingChannel(name, callback)); |
| + } |
| +} |
| + |
| +void ChannelMultiplexer::CreateDatagramChannel( |
| + const std::string& name, |
| + const DatagramChannelCallback& callback) { |
| + NOTIMPLEMENTED(); |
| + callback.Run(scoped_ptr<net::Socket>()); |
| +} |
| + |
| +void ChannelMultiplexer::OnBaseChannelReady( |
| + scoped_ptr<net::StreamSocket> socket) { |
| + base_channel_ = socket.Pass(); |
| + |
| + if (!base_channel_.get()) { |
| + base_channel_failed_ = true; |
| + |
| + // Notify all callers that we can't create any channels. |
| + for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
| + it != pending_channels_.end(); ++it) { |
| + it->callback.Run(scoped_ptr<net::StreamSocket>()); |
| + } |
| + pending_channels_.clear(); |
| + return; |
| + } |
| + |
| + // Initialize reader and writer. |
| + reader_.Init(base_channel_.get(), |
| + base::Bind(&ChannelMultiplexer::OnIncomingPacket, |
| + base::Unretained(this))); |
| + writer_.Init(base_channel_.get(), |
| + base::Bind(&ChannelMultiplexer::OnWriteFailed, |
| + base::Unretained(this))); |
| + |
| + // Now create all pending channels. |
| + for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); |
| + it != pending_channels_.end(); ++it) { |
| + it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); |
| + } |
| + pending_channels_.clear(); |
| +} |
| + |
| +ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( |
| + const std::string& name) { |
| + // Check if we already have a channel with the requested name. |
| + std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
| + if (it != channels_.end()) |
| + return it->second; |
| + |
| + // Create a new channel if we haven't found existing one. |
| + MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); |
| + ++next_channel_id_; |
| + channels_[channel->name()] = channel; |
| + return channel; |
| +} |
| + |
| + |
| +void ChannelMultiplexer::OnWriteFailed(int error) { |
| + for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); |
| + it != channels_.end(); ++it) { |
| + it->second->OnWriteFailed(); |
|
Wez
2012/08/06 23:14:07
This triggers caller-supplied callbacks, which mig
Sergey Ulanov
2012/08/07 20:12:22
Fixed this method to handle the case when multiple
|
| + } |
| +} |
| + |
| +void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MuxPacket> packet, |
| + const base::Closure& done_task) { |
| + if (!packet->has_channel_id()) { |
| + LOG(ERROR) << "Received packet without channel_id."; |
| + done_task.Run(); |
| + return; |
| + } |
| + |
| + int remote_id = packet->channel_id(); |
| + MuxChannel* channel = NULL; |
| + std::map<int, MuxChannel*>::iterator it = |
| + channels_by_remote_id_.find(remote_id); |
| + if (it != channels_by_remote_id_.end()) { |
| + channel = it->second; |
| + } else { |
| + // This is a new |channel_id| we haven't seen before. Look it up by name. |
| + if (!packet->has_channel_name()) { |
| + LOG(ERROR) << "Received packet with unknown channel_id and " |
| + "without channel_name."; |
| + done_task.Run(); |
| + return; |
| + } |
| + channel = GetOrCreateChannel(packet->channel_name()); |
|
Wez
2012/08/06 23:14:07
This has the disadvantage that if a peer sends lot
Sergey Ulanov
2012/08/07 20:12:22
MessageReader doesn't try to read from the channel
|
| + channel->set_remote_id(remote_id); |
| + channels_by_remote_id_[remote_id] = channel; |
| + } |
| + |
| + channel->OnIncomingPacket(packet.Pass(), done_task); |
| +} |
| + |
| +void ChannelMultiplexer::DoWrite(scoped_ptr<MuxPacket> packet, |
| + const base::Closure& done_task) { |
|
Wez
2012/08/06 23:14:07
nit: Indentation.
Sergey Ulanov
2012/08/07 20:12:22
Done.
|
| + writer_.Write(SerializeAndFrameMessage(*packet), done_task); |
| +} |
| + |
| +} // namespace protocol |
| +} // namespace remoting |