Chromium Code Reviews| OLD | NEW |
|---|---|
| (Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "remoting/protocol/channel_multiplexer.h" | |
| 6 | |
| 7 #include <string.h> | |
| 8 | |
| 9 #include "base/bind.h" | |
| 10 #include "base/callback.h" | |
| 11 #include "base/location.h" | |
| 12 #include "base/stl_util.h" | |
| 13 #include "net/base/net_errors.h" | |
| 14 #include "net/socket/stream_socket.h" | |
| 15 #include "remoting/protocol/util.h" | |
| 16 | |
| 17 namespace remoting { | |
| 18 namespace protocol { | |
| 19 | |
| 20 namespace { | |
| 21 const int kChannelIdUnknown = -1; | |
| 22 const int kMaxPacketSize = 1024; | |
| 23 | |
| 24 class PendingPacket { | |
| 25 public: | |
| 26 PendingPacket(scoped_ptr<MuxPacket> packet, const base::Closure& done_task) | |
| 27 : packet(packet.Pass()), | |
| 28 done_task(done_task), | |
| 29 pos(0U) { | |
| 30 } | |
| 31 ~PendingPacket() { | |
| 32 done_task.Run(); | |
| 33 } | |
| 34 | |
| 35 bool is_empty() { return pos >= packet->data().size(); } | |
| 36 | |
| 37 int Read(char* buffer, size_t size) { | |
| 38 size = std::min(size, packet->data().size() - pos); | |
| 39 memcpy(buffer, packet->data().data() + pos, size); | |
| 40 pos += size; | |
| 41 return size; | |
| 42 } | |
| 43 | |
| 44 private: | |
| 45 scoped_ptr<MuxPacket> packet; | |
| 46 base::Closure done_task; | |
| 47 size_t pos; | |
| 48 | |
| 49 DISALLOW_COPY_AND_ASSIGN(PendingPacket); | |
| 50 }; | |
| 51 | |
| 52 } // namespace | |
| 53 | |
| 54 const char ChannelMultiplexer::kMuxChannelName[] = "mux"; | |
| 55 | |
| 56 struct ChannelMultiplexer::PendingChannel { | |
| 57 PendingChannel(const std::string& name, | |
| 58 const StreamChannelCallback& callback) | |
| 59 : name(name), callback(callback) { | |
| 60 } | |
| 61 std::string name; | |
| 62 StreamChannelCallback callback; | |
| 63 }; | |
| 64 | |
| 65 class ChannelMultiplexer::MuxChannel { | |
| 66 public: | |
| 67 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, int id); | |
| 68 ~MuxChannel(); | |
| 69 | |
| 70 const std::string& name() { return name_; } | |
| 71 int id() { return id_; } | |
|
Wez
2012/08/06 23:14:07
nit: local_id?
Sergey Ulanov
2012/08/07 20:12:22
renamed to send id
| |
| 72 int remote_id() { return remote_id_; } | |
|
Wez
2012/08/06 23:14:07
nit: Or received_id() and send_id()?
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
| 73 void set_remote_id(int id) { remote_id_ = id; } | |
| 74 | |
| 75 // Called by ChannelMultiplexer. | |
| 76 scoped_ptr<net::StreamSocket> CreateSocket(); | |
| 77 void OnIncomingPacket(scoped_ptr<MuxPacket> packet, | |
| 78 const base::Closure& done_task); | |
| 79 void OnWriteFailed(); | |
| 80 | |
| 81 // Called by MuxSocket. | |
| 82 void OnSocketDestroyed(); | |
| 83 void DoWrite(scoped_ptr<MuxPacket> packet, const base::Closure& done_task); | |
| 84 int DoRead(net::IOBuffer* buffer, int buffer_len); | |
| 85 | |
| 86 private: | |
| 87 ChannelMultiplexer* multiplexer_; | |
| 88 std::string name_; | |
| 89 int id_; | |
| 90 bool id_sent_; | |
| 91 int remote_id_; | |
| 92 MuxSocket* socket_; | |
| 93 std::list<PendingPacket*> pending_packets_; | |
| 94 | |
| 95 DISALLOW_COPY_AND_ASSIGN(MuxChannel); | |
| 96 }; | |
| 97 | |
| 98 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, | |
| 99 public base::NonThreadSafe, | |
| 100 public base::SupportsWeakPtr<MuxSocket> { | |
| 101 public: | |
| 102 MuxSocket(MuxChannel* channel); | |
| 103 ~MuxSocket(); | |
| 104 | |
| 105 void OnWriteComplete(); | |
| 106 void OnWriteFailed(); | |
| 107 void OnPacketReceived(); | |
| 108 | |
| 109 // net::StreamSocket interface. | |
| 110 virtual int Read(net::IOBuffer* buffer, int buffer_len, | |
| 111 const net::CompletionCallback& callback) OVERRIDE; | |
| 112 virtual int Write(net::IOBuffer* buffer, int buffer_len, | |
| 113 const net::CompletionCallback& callback) OVERRIDE; | |
| 114 | |
| 115 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
| 116 NOTIMPLEMENTED(); | |
| 117 return false; | |
| 118 } | |
| 119 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
| 120 NOTIMPLEMENTED(); | |
| 121 return false; | |
| 122 } | |
| 123 | |
| 124 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { | |
| 125 NOTIMPLEMENTED(); | |
| 126 return net::ERR_FAILED; | |
| 127 } | |
| 128 virtual void Disconnect() OVERRIDE { | |
| 129 NOTIMPLEMENTED(); | |
| 130 } | |
| 131 virtual bool IsConnected() const OVERRIDE { | |
| 132 NOTIMPLEMENTED(); | |
| 133 return true; | |
| 134 } | |
| 135 virtual bool IsConnectedAndIdle() const OVERRIDE { | |
| 136 NOTIMPLEMENTED(); | |
| 137 return false; | |
| 138 } | |
| 139 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { | |
| 140 NOTIMPLEMENTED(); | |
| 141 return net::ERR_FAILED; | |
| 142 } | |
| 143 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { | |
| 144 NOTIMPLEMENTED(); | |
| 145 return net::ERR_FAILED; | |
| 146 } | |
| 147 virtual const net::BoundNetLog& NetLog() const OVERRIDE { | |
| 148 NOTIMPLEMENTED(); | |
| 149 return net_log_; | |
| 150 } | |
| 151 virtual void SetSubresourceSpeculation() OVERRIDE { | |
| 152 NOTIMPLEMENTED(); | |
| 153 } | |
| 154 virtual void SetOmniboxSpeculation() OVERRIDE { | |
| 155 NOTIMPLEMENTED(); | |
| 156 } | |
| 157 virtual bool WasEverUsed() const OVERRIDE { | |
| 158 return true; | |
| 159 } | |
| 160 virtual bool UsingTCPFastOpen() const OVERRIDE { | |
| 161 return false; | |
| 162 } | |
| 163 virtual int64 NumBytesRead() const OVERRIDE { | |
| 164 NOTIMPLEMENTED(); | |
| 165 return 0; | |
| 166 } | |
| 167 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { | |
| 168 NOTIMPLEMENTED(); | |
| 169 return base::TimeDelta(); | |
| 170 } | |
| 171 virtual bool WasNpnNegotiated() const OVERRIDE { | |
| 172 return false; | |
| 173 } | |
| 174 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { | |
| 175 return net::kProtoUnknown; | |
| 176 } | |
| 177 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { | |
| 178 NOTIMPLEMENTED(); | |
| 179 return false; | |
| 180 } | |
| 181 | |
| 182 private: | |
| 183 MuxChannel* channel_; | |
| 184 | |
| 185 net::CompletionCallback read_callback_; | |
| 186 scoped_refptr<net::IOBuffer> read_buffer_; | |
| 187 int read_buffer_size_; | |
| 188 | |
| 189 bool write_pending_; | |
| 190 int write_result_; | |
| 191 net::CompletionCallback write_callback_; | |
| 192 bool write_failed_; | |
| 193 | |
| 194 net::BoundNetLog net_log_; | |
| 195 | |
| 196 DISALLOW_COPY_AND_ASSIGN(MuxSocket); | |
| 197 }; | |
| 198 | |
| 199 | |
| 200 ChannelMultiplexer::MuxChannel::MuxChannel( | |
| 201 ChannelMultiplexer* multiplexer, | |
| 202 const std::string& name, | |
| 203 int id) | |
| 204 : multiplexer_(multiplexer), | |
| 205 name_(name), | |
| 206 id_(id), | |
| 207 id_sent_(false), | |
| 208 remote_id_(kChannelIdUnknown), | |
| 209 socket_(NULL) { | |
| 210 } | |
| 211 | |
| 212 ChannelMultiplexer::MuxChannel::~MuxChannel() { | |
| 213 // Socket must be destroyed before the channel. | |
| 214 DCHECK(!socket_); | |
| 215 STLDeleteElements(&pending_packets_); | |
| 216 } | |
| 217 | |
| 218 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { | |
| 219 DCHECK(!socket_); // Can't create more than one socket per channel. | |
| 220 scoped_ptr<MuxSocket> result(new MuxSocket(this)); | |
| 221 socket_ = result.get(); | |
| 222 return result.PassAs<net::StreamSocket>(); | |
| 223 } | |
| 224 | |
| 225 void ChannelMultiplexer::MuxChannel::OnIncomingPacket( | |
| 226 scoped_ptr<MuxPacket> packet, | |
| 227 const base::Closure& done_task) { | |
| 228 DCHECK_EQ(packet->channel_id(), remote_id_); | |
| 229 if (packet->data().size() > 0) { | |
| 230 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); | |
| 231 if (socket_) { | |
| 232 // Notify the socket that we have more data. | |
| 233 socket_->OnPacketReceived(); | |
| 234 } | |
| 235 } | |
| 236 } | |
| 237 | |
| 238 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { | |
| 239 if (socket_) | |
| 240 socket_->OnWriteFailed(); | |
| 241 } | |
| 242 | |
| 243 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { | |
| 244 DCHECK(socket_); | |
| 245 socket_ = NULL; | |
|
Wez
2012/08/06 23:14:07
nit: Remove the MuxChannel from the list of active
Sergey Ulanov
2012/08/07 20:12:22
We need to keep this object in order to preserve n
| |
| 246 } | |
| 247 | |
| 248 void ChannelMultiplexer::MuxChannel::DoWrite( | |
| 249 scoped_ptr<MuxPacket> packet, | |
| 250 const base::Closure& done_task) { | |
| 251 packet->set_channel_id(id_); | |
| 252 if (!id_sent_) { | |
| 253 packet->set_channel_name(name_); | |
| 254 id_sent_ = true; | |
| 255 } | |
| 256 multiplexer_->DoWrite(packet.Pass(), done_task); | |
| 257 } | |
| 258 | |
| 259 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, | |
| 260 int buffer_len) { | |
| 261 int pos = 0; | |
| 262 while (buffer_len > 0 && !pending_packets_.empty()) { | |
| 263 int result = pending_packets_.front()->Read( | |
| 264 buffer->data() + pos, buffer_len); | |
| 265 DCHECK_LE(result, buffer_len); | |
| 266 pos += result; | |
| 267 buffer_len -= pos; | |
| 268 while (!pending_packets_.empty() && | |
| 269 pending_packets_.front()->is_empty()) { | |
|
Wez
2012/08/06 23:14:07
Do you need a while loop here? Under what circumst
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
| 270 delete pending_packets_.front(); | |
| 271 pending_packets_.erase(pending_packets_.begin()); | |
| 272 } | |
| 273 } | |
| 274 return pos; | |
| 275 } | |
| 276 | |
| 277 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) | |
| 278 : channel_(channel), | |
| 279 read_buffer_size_(0), | |
| 280 write_pending_(false), | |
| 281 write_result_(0), | |
| 282 write_failed_(false) { | |
| 283 } | |
| 284 | |
| 285 ChannelMultiplexer::MuxSocket::~MuxSocket() { | |
|
Wez
2012/08/06 23:14:07
Thread check before OnSocketDestroyed()?
Sergey Ulanov
2012/08/07 20:12:22
The class inherits from NonThreadSafe and ~NonThre
Wez
2012/08/07 21:44:44
Not until after OnSocketDestroyed() has already be
| |
| 286 channel_->OnSocketDestroyed(); | |
| 287 } | |
| 288 | |
| 289 int ChannelMultiplexer::MuxSocket::Read( | |
| 290 net::IOBuffer* buffer, int buffer_len, | |
| 291 const net::CompletionCallback& callback) { | |
| 292 DCHECK(CalledOnValidThread()); | |
| 293 DCHECK(read_callback_.is_null()); | |
| 294 | |
| 295 int result = channel_->DoRead(buffer, buffer_len); | |
| 296 if (result == 0) { | |
| 297 read_buffer_ = buffer; | |
| 298 read_buffer_size_ = buffer_len; | |
| 299 read_callback_ = callback; | |
| 300 return net::ERR_IO_PENDING; | |
| 301 } | |
| 302 return result; | |
| 303 } | |
| 304 | |
| 305 int ChannelMultiplexer::MuxSocket::Write( | |
| 306 net::IOBuffer* buffer, int buffer_len, | |
| 307 const net::CompletionCallback& callback) { | |
| 308 DCHECK(CalledOnValidThread()); | |
| 309 | |
| 310 if (write_failed_) | |
| 311 return net::ERR_FAILED; | |
| 312 | |
| 313 scoped_ptr<MuxPacket> packet(new MuxPacket()); | |
| 314 size_t size = std::min(kMaxPacketSize, buffer_len); | |
| 315 packet->mutable_data()->assign(buffer->data(), size); | |
| 316 | |
| 317 write_pending_ = true; | |
| 318 channel_->DoWrite(packet.Pass(), base::Bind( | |
| 319 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); | |
|
Wez
2012/08/06 23:14:07
nit: It's a little confusing that we set write_pen
Sergey Ulanov
2012/08/07 20:12:22
The semantics used by BufferedSocketWriter is easi
| |
| 320 | |
| 321 if (write_pending_) { | |
| 322 DCHECK(write_callback_.is_null()); | |
| 323 write_callback_ = callback; | |
| 324 write_result_ = size; | |
| 325 return net::ERR_IO_PENDING; | |
| 326 } | |
| 327 | |
| 328 return write_failed_ ? net::ERR_FAILED : size; | |
| 329 } | |
| 330 | |
| 331 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { | |
| 332 write_pending_ = false; | |
| 333 if (!write_callback_.is_null()) { | |
| 334 net::CompletionCallback cb; | |
| 335 std::swap(cb, write_callback_); | |
| 336 cb.Run(write_result_); | |
| 337 } | |
| 338 } | |
| 339 | |
| 340 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { | |
| 341 if (!write_callback_.is_null()) { | |
| 342 net::CompletionCallback cb; | |
| 343 std::swap(cb, write_callback_); | |
| 344 cb.Run(net::ERR_FAILED); | |
|
Wez
2012/08/06 23:14:07
Don't you need to do this after setting write_fail
Sergey Ulanov
2012/08/07 20:12:22
Good point. Fixed.
| |
| 345 } | |
| 346 write_failed_ = true; | |
|
Wez
2012/08/06 23:14:07
nit: Do you actually need to keep a write_failed f
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
| 347 } | |
| 348 | |
| 349 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { | |
| 350 if (!read_callback_.is_null()) { | |
| 351 int result = channel_->DoRead(read_buffer_, read_buffer_size_); | |
| 352 read_buffer_ = NULL; | |
| 353 DCHECK_GT(result, 0); | |
| 354 net::CompletionCallback cb; | |
| 355 std::swap(cb, read_callback_); | |
| 356 cb.Run(result); | |
| 357 } | |
| 358 } | |
| 359 | |
| 360 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory) | |
| 361 : base_channel_failed_(false), | |
| 362 next_channel_id_(0) { | |
| 363 factory->CreateStreamChannel( | |
| 364 kMuxChannelName, base::Bind(&ChannelMultiplexer::OnBaseChannelReady, | |
| 365 base::Unretained(this))); | |
| 366 } | |
| 367 | |
| 368 ChannelMultiplexer::~ChannelMultiplexer() { | |
| 369 STLDeleteValues(&channels_); | |
| 370 } | |
| 371 | |
| 372 void ChannelMultiplexer::CreateStreamChannel( | |
| 373 const std::string& name, | |
| 374 const StreamChannelCallback& callback) { | |
| 375 if (base_channel_.get()) { | |
| 376 // Already have |base_channel_|. Create new multiplexed channel | |
| 377 // synchronously. | |
| 378 callback.Run(GetOrCreateChannel(name)->CreateSocket()); | |
| 379 } else if (base_channel_failed_) { | |
| 380 // Fail synchronously if we failed to create |base_channel_|. | |
| 381 callback.Run(scoped_ptr<net::StreamSocket>()); | |
| 382 } else { | |
| 383 // Still waiting for the |base_channel_|. | |
| 384 pending_channels_.push_back(PendingChannel(name, callback)); | |
| 385 } | |
| 386 } | |
| 387 | |
| 388 void ChannelMultiplexer::CreateDatagramChannel( | |
| 389 const std::string& name, | |
| 390 const DatagramChannelCallback& callback) { | |
| 391 NOTIMPLEMENTED(); | |
| 392 callback.Run(scoped_ptr<net::Socket>()); | |
| 393 } | |
| 394 | |
| 395 void ChannelMultiplexer::OnBaseChannelReady( | |
| 396 scoped_ptr<net::StreamSocket> socket) { | |
| 397 base_channel_ = socket.Pass(); | |
| 398 | |
| 399 if (!base_channel_.get()) { | |
| 400 base_channel_failed_ = true; | |
| 401 | |
| 402 // Notify all callers that we can't create any channels. | |
| 403 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
| 404 it != pending_channels_.end(); ++it) { | |
| 405 it->callback.Run(scoped_ptr<net::StreamSocket>()); | |
| 406 } | |
| 407 pending_channels_.clear(); | |
| 408 return; | |
| 409 } | |
| 410 | |
| 411 // Initialize reader and writer. | |
| 412 reader_.Init(base_channel_.get(), | |
| 413 base::Bind(&ChannelMultiplexer::OnIncomingPacket, | |
| 414 base::Unretained(this))); | |
| 415 writer_.Init(base_channel_.get(), | |
| 416 base::Bind(&ChannelMultiplexer::OnWriteFailed, | |
| 417 base::Unretained(this))); | |
| 418 | |
| 419 // Now create all pending channels. | |
| 420 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
| 421 it != pending_channels_.end(); ++it) { | |
| 422 it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); | |
| 423 } | |
| 424 pending_channels_.clear(); | |
| 425 } | |
| 426 | |
| 427 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( | |
| 428 const std::string& name) { | |
| 429 // Check if we already have a channel with the requested name. | |
| 430 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | |
| 431 if (it != channels_.end()) | |
| 432 return it->second; | |
| 433 | |
| 434 // Create a new channel if we haven't found existing one. | |
| 435 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | |
| 436 ++next_channel_id_; | |
| 437 channels_[channel->name()] = channel; | |
| 438 return channel; | |
| 439 } | |
| 440 | |
| 441 | |
| 442 void ChannelMultiplexer::OnWriteFailed(int error) { | |
| 443 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | |
| 444 it != channels_.end(); ++it) { | |
| 445 it->second->OnWriteFailed(); | |
|
Wez
2012/08/06 23:14:07
This triggers caller-supplied callbacks, which mig
Sergey Ulanov
2012/08/07 20:12:22
Fixed this method to handle the case when multiple
| |
| 446 } | |
| 447 } | |
| 448 | |
| 449 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MuxPacket> packet, | |
| 450 const base::Closure& done_task) { | |
| 451 if (!packet->has_channel_id()) { | |
| 452 LOG(ERROR) << "Received packet without channel_id."; | |
| 453 done_task.Run(); | |
| 454 return; | |
| 455 } | |
| 456 | |
| 457 int remote_id = packet->channel_id(); | |
| 458 MuxChannel* channel = NULL; | |
| 459 std::map<int, MuxChannel*>::iterator it = | |
| 460 channels_by_remote_id_.find(remote_id); | |
| 461 if (it != channels_by_remote_id_.end()) { | |
| 462 channel = it->second; | |
| 463 } else { | |
| 464 // This is a new |channel_id| we haven't seen before. Look it up by name. | |
| 465 if (!packet->has_channel_name()) { | |
| 466 LOG(ERROR) << "Received packet with unknown channel_id and " | |
| 467 "without channel_name."; | |
| 468 done_task.Run(); | |
| 469 return; | |
| 470 } | |
| 471 channel = GetOrCreateChannel(packet->channel_name()); | |
|
Wez
2012/08/06 23:14:07
This has the disadvantage that if a peer sends lot
Sergey Ulanov
2012/08/07 20:12:22
MessageReader doesn't try to read from the channel
| |
| 472 channel->set_remote_id(remote_id); | |
| 473 channels_by_remote_id_[remote_id] = channel; | |
| 474 } | |
| 475 | |
| 476 channel->OnIncomingPacket(packet.Pass(), done_task); | |
| 477 } | |
| 478 | |
| 479 void ChannelMultiplexer::DoWrite(scoped_ptr<MuxPacket> packet, | |
| 480 const base::Closure& done_task) { | |
|
Wez
2012/08/06 23:14:07
nit: Indentation.
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
| 481 writer_.Write(SerializeAndFrameMessage(*packet), done_task); | |
| 482 } | |
| 483 | |
| 484 } // namespace protocol | |
| 485 } // namespace remoting | |
| OLD | NEW |