Chromium Code Reviews| OLD | NEW |
|---|---|
| (Empty) | |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "remoting/protocol/channel_multiplexer.h" | |
| 6 | |
| 7 #include <string.h> | |
| 8 | |
| 9 #include "base/bind.h" | |
| 10 #include "base/callback.h" | |
| 11 #include "base/location.h" | |
| 12 #include "base/stl_util.h" | |
| 13 #include "net/base/net_errors.h" | |
| 14 #include "net/socket/stream_socket.h" | |
| 15 #include "remoting/protocol/util.h" | |
| 16 | |
| 17 namespace remoting { | |
| 18 namespace protocol { | |
| 19 | |
| 20 namespace { | |
| 21 const int kChannelIdUnknown = -1; | |
| 22 const int kMaxPacketSize = 1024; | |
| 23 | |
| 24 class PendingPacket { | |
| 25 public: | |
| 26 PendingPacket(scoped_ptr<MultiplexPacket> packet, | |
| 27 const base::Closure& done_task) | |
| 28 : packet(packet.Pass()), | |
| 29 done_task(done_task), | |
| 30 pos(0U) { | |
| 31 } | |
| 32 ~PendingPacket() { | |
| 33 done_task.Run(); | |
| 34 } | |
| 35 | |
| 36 bool is_empty() { return pos >= packet->data().size(); } | |
| 37 | |
| 38 int Read(char* buffer, size_t size) { | |
| 39 size = std::min(size, packet->data().size() - pos); | |
| 40 memcpy(buffer, packet->data().data() + pos, size); | |
| 41 pos += size; | |
| 42 return size; | |
| 43 } | |
| 44 | |
| 45 private: | |
| 46 scoped_ptr<MultiplexPacket> packet; | |
| 47 base::Closure done_task; | |
| 48 size_t pos; | |
| 49 | |
| 50 DISALLOW_COPY_AND_ASSIGN(PendingPacket); | |
| 51 }; | |
| 52 | |
| 53 } // namespace | |
| 54 | |
| 55 const char ChannelMultiplexer::kMuxChannelName[] = "mux"; | |
| 56 | |
| 57 struct ChannelMultiplexer::PendingChannel { | |
| 58 PendingChannel(const std::string& name, | |
| 59 const StreamChannelCallback& callback) | |
| 60 : name(name), callback(callback) { | |
| 61 } | |
| 62 std::string name; | |
| 63 StreamChannelCallback callback; | |
| 64 }; | |
| 65 | |
| 66 class ChannelMultiplexer::MuxChannel { | |
| 67 public: | |
| 68 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, | |
| 69 int send_id); | |
| 70 ~MuxChannel(); | |
| 71 | |
| 72 const std::string& name() { return name_; } | |
| 73 int receive_id() { return receive_id_; } | |
| 74 void set_receive_id(int id) { receive_id_ = id; } | |
| 75 | |
| 76 // Called by ChannelMultiplexer. | |
| 77 scoped_ptr<net::StreamSocket> CreateSocket(); | |
| 78 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | |
| 79 const base::Closure& done_task); | |
| 80 void OnWriteFailed(); | |
| 81 | |
| 82 // Called by MuxSocket. | |
| 83 void OnSocketDestroyed(); | |
| 84 bool DoWrite(scoped_ptr<MultiplexPacket> packet, | |
| 85 const base::Closure& done_task); | |
| 86 int DoRead(net::IOBuffer* buffer, int buffer_len); | |
| 87 | |
| 88 private: | |
| 89 ChannelMultiplexer* multiplexer_; | |
| 90 std::string name_; | |
| 91 int send_id_; | |
| 92 bool id_sent_; | |
| 93 int receive_id_; | |
| 94 MuxSocket* socket_; | |
| 95 std::list<PendingPacket*> pending_packets_; | |
| 96 | |
| 97 DISALLOW_COPY_AND_ASSIGN(MuxChannel); | |
| 98 }; | |
| 99 | |
| 100 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, | |
| 101 public base::NonThreadSafe, | |
| 102 public base::SupportsWeakPtr<MuxSocket> { | |
| 103 public: | |
| 104 MuxSocket(MuxChannel* channel); | |
| 105 ~MuxSocket(); | |
| 106 | |
| 107 void OnWriteComplete(); | |
| 108 void OnWriteFailed(); | |
| 109 void OnPacketReceived(); | |
| 110 | |
| 111 // net::StreamSocket interface. | |
| 112 virtual int Read(net::IOBuffer* buffer, int buffer_len, | |
| 113 const net::CompletionCallback& callback) OVERRIDE; | |
| 114 virtual int Write(net::IOBuffer* buffer, int buffer_len, | |
| 115 const net::CompletionCallback& callback) OVERRIDE; | |
| 116 | |
| 117 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
| 118 NOTIMPLEMENTED(); | |
| 119 return false; | |
| 120 } | |
| 121 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
| 122 NOTIMPLEMENTED(); | |
| 123 return false; | |
| 124 } | |
| 125 | |
| 126 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { | |
| 127 NOTIMPLEMENTED(); | |
| 128 return net::ERR_FAILED; | |
| 129 } | |
| 130 virtual void Disconnect() OVERRIDE { | |
| 131 NOTIMPLEMENTED(); | |
| 132 } | |
| 133 virtual bool IsConnected() const OVERRIDE { | |
| 134 NOTIMPLEMENTED(); | |
| 135 return true; | |
| 136 } | |
| 137 virtual bool IsConnectedAndIdle() const OVERRIDE { | |
| 138 NOTIMPLEMENTED(); | |
| 139 return false; | |
| 140 } | |
| 141 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { | |
| 142 NOTIMPLEMENTED(); | |
| 143 return net::ERR_FAILED; | |
| 144 } | |
| 145 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { | |
| 146 NOTIMPLEMENTED(); | |
| 147 return net::ERR_FAILED; | |
| 148 } | |
| 149 virtual const net::BoundNetLog& NetLog() const OVERRIDE { | |
| 150 NOTIMPLEMENTED(); | |
| 151 return net_log_; | |
| 152 } | |
| 153 virtual void SetSubresourceSpeculation() OVERRIDE { | |
| 154 NOTIMPLEMENTED(); | |
| 155 } | |
| 156 virtual void SetOmniboxSpeculation() OVERRIDE { | |
| 157 NOTIMPLEMENTED(); | |
| 158 } | |
| 159 virtual bool WasEverUsed() const OVERRIDE { | |
| 160 return true; | |
| 161 } | |
| 162 virtual bool UsingTCPFastOpen() const OVERRIDE { | |
| 163 return false; | |
| 164 } | |
| 165 virtual int64 NumBytesRead() const OVERRIDE { | |
| 166 NOTIMPLEMENTED(); | |
| 167 return 0; | |
| 168 } | |
| 169 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { | |
| 170 NOTIMPLEMENTED(); | |
| 171 return base::TimeDelta(); | |
| 172 } | |
| 173 virtual bool WasNpnNegotiated() const OVERRIDE { | |
| 174 return false; | |
| 175 } | |
| 176 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { | |
| 177 return net::kProtoUnknown; | |
| 178 } | |
| 179 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { | |
| 180 NOTIMPLEMENTED(); | |
| 181 return false; | |
| 182 } | |
| 183 | |
| 184 private: | |
| 185 MuxChannel* channel_; | |
| 186 | |
| 187 net::CompletionCallback read_callback_; | |
| 188 scoped_refptr<net::IOBuffer> read_buffer_; | |
| 189 int read_buffer_size_; | |
| 190 | |
| 191 bool write_pending_; | |
| 192 int write_result_; | |
| 193 net::CompletionCallback write_callback_; | |
| 194 | |
| 195 net::BoundNetLog net_log_; | |
| 196 | |
| 197 DISALLOW_COPY_AND_ASSIGN(MuxSocket); | |
| 198 }; | |
| 199 | |
| 200 | |
| 201 ChannelMultiplexer::MuxChannel::MuxChannel( | |
| 202 ChannelMultiplexer* multiplexer, | |
| 203 const std::string& name, | |
| 204 int send_id) | |
| 205 : multiplexer_(multiplexer), | |
| 206 name_(name), | |
| 207 send_id_(send_id), | |
| 208 id_sent_(false), | |
| 209 receive_id_(kChannelIdUnknown), | |
| 210 socket_(NULL) { | |
| 211 } | |
| 212 | |
| 213 ChannelMultiplexer::MuxChannel::~MuxChannel() { | |
| 214 // Socket must be destroyed before the channel. | |
| 215 DCHECK(!socket_); | |
| 216 STLDeleteElements(&pending_packets_); | |
| 217 } | |
| 218 | |
| 219 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { | |
| 220 DCHECK(!socket_); // Can't create more than one socket per channel. | |
| 221 scoped_ptr<MuxSocket> result(new MuxSocket(this)); | |
| 222 socket_ = result.get(); | |
| 223 return result.PassAs<net::StreamSocket>(); | |
| 224 } | |
| 225 | |
| 226 void ChannelMultiplexer::MuxChannel::OnIncomingPacket( | |
| 227 scoped_ptr<MultiplexPacket> packet, | |
| 228 const base::Closure& done_task) { | |
| 229 DCHECK_EQ(packet->channel_id(), receive_id_); | |
| 230 if (packet->data().size() > 0) { | |
| 231 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); | |
| 232 if (socket_) { | |
| 233 // Notify the socket that we have more data. | |
| 234 socket_->OnPacketReceived(); | |
| 235 } | |
| 236 } | |
| 237 } | |
| 238 | |
| 239 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { | |
| 240 if (socket_) | |
| 241 socket_->OnWriteFailed(); | |
| 242 } | |
| 243 | |
| 244 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { | |
| 245 DCHECK(socket_); | |
| 246 socket_ = NULL; | |
| 247 } | |
| 248 | |
| 249 bool ChannelMultiplexer::MuxChannel::DoWrite( | |
| 250 scoped_ptr<MultiplexPacket> packet, | |
| 251 const base::Closure& done_task) { | |
| 252 packet->set_channel_id(send_id_); | |
| 253 if (!id_sent_) { | |
| 254 packet->set_channel_name(name_); | |
| 255 id_sent_ = true; | |
| 256 } | |
| 257 return multiplexer_->DoWrite(packet.Pass(), done_task); | |
| 258 } | |
| 259 | |
| 260 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, | |
| 261 int buffer_len) { | |
| 262 int pos = 0; | |
| 263 while (buffer_len > 0 && !pending_packets_.empty()) { | |
| 264 DCHECK(!pending_packets_.front()->is_empty()); | |
| 265 int result = pending_packets_.front()->Read( | |
| 266 buffer->data() + pos, buffer_len); | |
| 267 DCHECK_LE(result, buffer_len); | |
| 268 pos += result; | |
| 269 buffer_len -= pos; | |
| 270 if (pending_packets_.front()->is_empty()) { | |
| 271 delete pending_packets_.front(); | |
| 272 pending_packets_.erase(pending_packets_.begin()); | |
| 273 } | |
| 274 } | |
| 275 return pos; | |
| 276 } | |
| 277 | |
| 278 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) | |
| 279 : channel_(channel), | |
| 280 read_buffer_size_(0), | |
| 281 write_pending_(false), | |
| 282 write_result_(0) { | |
| 283 } | |
| 284 | |
| 285 ChannelMultiplexer::MuxSocket::~MuxSocket() { | |
| 286 channel_->OnSocketDestroyed(); | |
| 287 } | |
| 288 | |
| 289 int ChannelMultiplexer::MuxSocket::Read( | |
| 290 net::IOBuffer* buffer, int buffer_len, | |
| 291 const net::CompletionCallback& callback) { | |
| 292 DCHECK(CalledOnValidThread()); | |
| 293 DCHECK(read_callback_.is_null()); | |
| 294 | |
| 295 int result = channel_->DoRead(buffer, buffer_len); | |
| 296 if (result == 0) { | |
| 297 read_buffer_ = buffer; | |
| 298 read_buffer_size_ = buffer_len; | |
| 299 read_callback_ = callback; | |
| 300 return net::ERR_IO_PENDING; | |
| 301 } | |
| 302 return result; | |
| 303 } | |
| 304 | |
| 305 int ChannelMultiplexer::MuxSocket::Write( | |
| 306 net::IOBuffer* buffer, int buffer_len, | |
| 307 const net::CompletionCallback& callback) { | |
| 308 DCHECK(CalledOnValidThread()); | |
| 309 | |
| 310 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); | |
| 311 size_t size = std::min(kMaxPacketSize, buffer_len); | |
| 312 packet->mutable_data()->assign(buffer->data(), size); | |
| 313 | |
| 314 write_pending_ = true; | |
| 315 bool result = channel_->DoWrite(packet.Pass(), base::Bind( | |
| 316 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); | |
| 317 | |
| 318 if (!result) { | |
| 319 // Cannot complete the write, e.g. if the connection has been terminated. | |
| 320 return net::ERR_FAILED; | |
| 321 } | |
| 322 | |
| 323 // OnWriteComplete() might be called above synchronously. | |
| 324 if (write_pending_) { | |
| 325 DCHECK(write_callback_.is_null()); | |
| 326 write_callback_ = callback; | |
| 327 write_result_ = size; | |
| 328 return net::ERR_IO_PENDING; | |
| 329 } | |
| 330 | |
| 331 return size; | |
| 332 } | |
| 333 | |
| 334 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { | |
| 335 write_pending_ = false; | |
| 336 if (!write_callback_.is_null()) { | |
| 337 net::CompletionCallback cb; | |
| 338 std::swap(cb, write_callback_); | |
| 339 cb.Run(write_result_); | |
| 340 } | |
| 341 } | |
| 342 | |
| 343 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { | |
| 344 if (!write_callback_.is_null()) { | |
| 345 net::CompletionCallback cb; | |
| 346 std::swap(cb, write_callback_); | |
| 347 cb.Run(net::ERR_FAILED); | |
| 348 } | |
| 349 } | |
| 350 | |
| 351 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { | |
| 352 if (!read_callback_.is_null()) { | |
| 353 int result = channel_->DoRead(read_buffer_, read_buffer_size_); | |
| 354 read_buffer_ = NULL; | |
| 355 DCHECK_GT(result, 0); | |
| 356 net::CompletionCallback cb; | |
| 357 std::swap(cb, read_callback_); | |
| 358 cb.Run(result); | |
| 359 } | |
| 360 } | |
| 361 | |
| 362 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, | |
| 363 const std::string& base_channel_name) | |
| 364 : base_channel_factory_(factory), | |
| 365 base_channel_name_(base_channel_name), | |
| 366 next_channel_id_(0), | |
| 367 destroyed_flag_(NULL) { | |
| 368 factory->CreateStreamChannel( | |
| 369 base_channel_name, | |
| 370 base::Bind(&ChannelMultiplexer::OnBaseChannelReady, | |
| 371 base::Unretained(this))); | |
| 372 } | |
| 373 | |
| 374 ChannelMultiplexer::~ChannelMultiplexer() { | |
| 375 DCHECK(pending_channels_.empty()); | |
| 376 STLDeleteValues(&channels_); | |
| 377 | |
| 378 // Cancel creation of the base channel if it hasn't finished. | |
| 379 if (base_channel_factory_) | |
| 380 base_channel_factory_->CancelChannelCreation(base_channel_name_); | |
| 381 | |
| 382 if (destroyed_flag_) | |
| 383 *destroyed_flag_ = true; | |
| 384 } | |
| 385 | |
| 386 void ChannelMultiplexer::CreateStreamChannel( | |
| 387 const std::string& name, | |
| 388 const StreamChannelCallback& callback) { | |
| 389 if (base_channel_.get()) { | |
| 390 // Already have |base_channel_|. Create new multiplexed channel | |
| 391 // synchronously. | |
| 392 callback.Run(GetOrCreateChannel(name)->CreateSocket()); | |
| 393 } else if (!base_channel_.get() && !base_channel_factory_) { | |
| 394 // Fail synchronously if we failed to create |base_channel_|. | |
| 395 callback.Run(scoped_ptr<net::StreamSocket>()); | |
| 396 } else { | |
| 397 // Still waiting for the |base_channel_|. | |
| 398 pending_channels_.push_back(PendingChannel(name, callback)); | |
| 399 } | |
| 400 } | |
| 401 | |
| 402 void ChannelMultiplexer::CreateDatagramChannel( | |
| 403 const std::string& name, | |
| 404 const DatagramChannelCallback& callback) { | |
| 405 NOTIMPLEMENTED(); | |
| 406 callback.Run(scoped_ptr<net::Socket>()); | |
| 407 } | |
| 408 | |
| 409 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { | |
| 410 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
| 411 it != pending_channels_.end(); ++it) { | |
| 412 if (it->name == name) { | |
| 413 pending_channels_.erase(it); | |
| 414 return; | |
| 415 } | |
| 416 } | |
| 417 NOTREACHED(); | |
|
Wez
2012/08/07 21:44:45
Do you really want a NOTREACHED() here; it seems v
Sergey Ulanov
2012/08/07 22:25:14
Yes, you are right.
| |
| 418 } | |
| 419 | |
| 420 void ChannelMultiplexer::OnBaseChannelReady( | |
| 421 scoped_ptr<net::StreamSocket> socket) { | |
| 422 base_channel_factory_ = NULL; | |
| 423 base_channel_ = socket.Pass(); | |
| 424 | |
| 425 if (!base_channel_.get()) { | |
| 426 // Notify all callers that we can't create any channels. | |
| 427 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
| 428 it != pending_channels_.end(); ++it) { | |
| 429 it->callback.Run(scoped_ptr<net::StreamSocket>()); | |
| 430 } | |
| 431 pending_channels_.clear(); | |
| 432 return; | |
| 433 } | |
| 434 | |
| 435 // Initialize reader and writer. | |
| 436 reader_.Init(base_channel_.get(), | |
| 437 base::Bind(&ChannelMultiplexer::OnIncomingPacket, | |
| 438 base::Unretained(this))); | |
| 439 writer_.Init(base_channel_.get(), | |
| 440 base::Bind(&ChannelMultiplexer::OnWriteFailed, | |
| 441 base::Unretained(this))); | |
| 442 | |
| 443 // Now create all pending channels. | |
| 444 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
| 445 it != pending_channels_.end(); ++it) { | |
| 446 it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); | |
| 447 } | |
| 448 pending_channels_.clear(); | |
| 449 } | |
| 450 | |
| 451 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( | |
| 452 const std::string& name) { | |
| 453 // Check if we already have a channel with the requested name. | |
| 454 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | |
| 455 if (it != channels_.end()) | |
| 456 return it->second; | |
| 457 | |
| 458 // Create a new channel if we haven't found existing one. | |
| 459 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | |
| 460 ++next_channel_id_; | |
| 461 channels_[channel->name()] = channel; | |
| 462 return channel; | |
| 463 } | |
| 464 | |
| 465 | |
| 466 void ChannelMultiplexer::OnWriteFailed(int error) { | |
| 467 bool destroyed = false; | |
| 468 destroyed_flag_ = &destroyed; | |
| 469 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | |
| 470 it != channels_.end(); ++it) { | |
| 471 it->second->OnWriteFailed(); | |
| 472 if (destroyed) | |
| 473 return; | |
| 474 } | |
| 475 destroyed_flag_ = NULL; | |
| 476 } | |
| 477 | |
| 478 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | |
| 479 const base::Closure& done_task) { | |
| 480 if (!packet->has_channel_id()) { | |
| 481 LOG(ERROR) << "Received packet without channel_id."; | |
| 482 done_task.Run(); | |
| 483 return; | |
| 484 } | |
| 485 | |
| 486 int receive_id = packet->channel_id(); | |
| 487 MuxChannel* channel = NULL; | |
| 488 std::map<int, MuxChannel*>::iterator it = | |
| 489 channels_by_receive_id_.find(receive_id); | |
| 490 if (it != channels_by_receive_id_.end()) { | |
| 491 channel = it->second; | |
| 492 } else { | |
| 493 // This is a new |channel_id| we haven't seen before. Look it up by name. | |
| 494 if (!packet->has_channel_name()) { | |
| 495 LOG(ERROR) << "Received packet with unknown channel_id and " | |
| 496 "without channel_name."; | |
| 497 done_task.Run(); | |
| 498 return; | |
| 499 } | |
| 500 channel = GetOrCreateChannel(packet->channel_name()); | |
| 501 channel->set_receive_id(receive_id); | |
| 502 channels_by_receive_id_[receive_id] = channel; | |
| 503 } | |
| 504 | |
| 505 channel->OnIncomingPacket(packet.Pass(), done_task); | |
| 506 } | |
| 507 | |
| 508 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, | |
| 509 const base::Closure& done_task) { | |
| 510 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); | |
| 511 } | |
| 512 | |
| 513 } // namespace protocol | |
| 514 } // namespace remoting | |
| OLD | NEW |