| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "remoting/protocol/channel_multiplexer.h" | 5 #include "remoting/protocol/channel_multiplexer.h" |
| 6 | 6 |
| 7 #include <string.h> | 7 #include <string.h> |
| 8 | 8 |
| 9 #include "base/bind.h" | 9 #include "base/bind.h" |
| 10 #include "base/callback.h" | 10 #include "base/callback.h" |
| 11 #include "base/callback_helpers.h" |
| 11 #include "base/location.h" | 12 #include "base/location.h" |
| 12 #include "base/single_thread_task_runner.h" | 13 #include "base/single_thread_task_runner.h" |
| 13 #include "base/stl_util.h" | 14 #include "base/stl_util.h" |
| 14 #include "base/thread_task_runner_handle.h" | 15 #include "base/thread_task_runner_handle.h" |
| 15 #include "net/base/net_errors.h" | 16 #include "net/base/net_errors.h" |
| 16 #include "net/socket/stream_socket.h" | 17 #include "net/socket/stream_socket.h" |
| 17 #include "remoting/protocol/message_serialization.h" | 18 #include "remoting/protocol/message_serialization.h" |
| 18 | 19 |
| 19 namespace remoting { | 20 namespace remoting { |
| 20 namespace protocol { | 21 namespace protocol { |
| (...skipping 51 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 72 ~MuxChannel(); | 73 ~MuxChannel(); |
| 73 | 74 |
| 74 const std::string& name() { return name_; } | 75 const std::string& name() { return name_; } |
| 75 int receive_id() { return receive_id_; } | 76 int receive_id() { return receive_id_; } |
| 76 void set_receive_id(int id) { receive_id_ = id; } | 77 void set_receive_id(int id) { receive_id_ = id; } |
| 77 | 78 |
| 78 // Called by ChannelMultiplexer. | 79 // Called by ChannelMultiplexer. |
| 79 scoped_ptr<net::StreamSocket> CreateSocket(); | 80 scoped_ptr<net::StreamSocket> CreateSocket(); |
| 80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | 81 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
| 81 const base::Closure& done_task); | 82 const base::Closure& done_task); |
| 82 void OnWriteFailed(); | 83 void OnBaseChannelError(int error); |
| 83 | 84 |
| 84 // Called by MuxSocket. | 85 // Called by MuxSocket. |
| 85 void OnSocketDestroyed(); | 86 void OnSocketDestroyed(); |
| 86 bool DoWrite(scoped_ptr<MultiplexPacket> packet, | 87 bool DoWrite(scoped_ptr<MultiplexPacket> packet, |
| 87 const base::Closure& done_task); | 88 const base::Closure& done_task); |
| 88 int DoRead(net::IOBuffer* buffer, int buffer_len); | 89 int DoRead(net::IOBuffer* buffer, int buffer_len); |
| 89 | 90 |
| 90 private: | 91 private: |
| 91 ChannelMultiplexer* multiplexer_; | 92 ChannelMultiplexer* multiplexer_; |
| 92 std::string name_; | 93 std::string name_; |
| 93 int send_id_; | 94 int send_id_; |
| 94 bool id_sent_; | 95 bool id_sent_; |
| 95 int receive_id_; | 96 int receive_id_; |
| 96 MuxSocket* socket_; | 97 MuxSocket* socket_; |
| 97 std::list<PendingPacket*> pending_packets_; | 98 std::list<PendingPacket*> pending_packets_; |
| 98 | 99 |
| 99 DISALLOW_COPY_AND_ASSIGN(MuxChannel); | 100 DISALLOW_COPY_AND_ASSIGN(MuxChannel); |
| 100 }; | 101 }; |
| 101 | 102 |
| 102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, | 103 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, |
| 103 public base::NonThreadSafe, | 104 public base::NonThreadSafe, |
| 104 public base::SupportsWeakPtr<MuxSocket> { | 105 public base::SupportsWeakPtr<MuxSocket> { |
| 105 public: | 106 public: |
| 106 MuxSocket(MuxChannel* channel); | 107 MuxSocket(MuxChannel* channel); |
| 107 ~MuxSocket() override; | 108 ~MuxSocket() override; |
| 108 | 109 |
| 109 void OnWriteComplete(); | 110 void OnWriteComplete(); |
| 110 void OnWriteFailed(); | 111 void OnBaseChannelError(int error); |
| 111 void OnPacketReceived(); | 112 void OnPacketReceived(); |
| 112 | 113 |
| 113 // net::StreamSocket interface. | 114 // net::StreamSocket interface. |
| 114 int Read(net::IOBuffer* buffer, | 115 int Read(net::IOBuffer* buffer, |
| 115 int buffer_len, | 116 int buffer_len, |
| 116 const net::CompletionCallback& callback) override; | 117 const net::CompletionCallback& callback) override; |
| 117 int Write(net::IOBuffer* buffer, | 118 int Write(net::IOBuffer* buffer, |
| 118 int buffer_len, | 119 int buffer_len, |
| 119 const net::CompletionCallback& callback) override; | 120 const net::CompletionCallback& callback) override; |
| 120 | 121 |
| (...skipping 40 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 161 return net::kProtoUnknown; | 162 return net::kProtoUnknown; |
| 162 } | 163 } |
| 163 bool GetSSLInfo(net::SSLInfo* ssl_info) override { | 164 bool GetSSLInfo(net::SSLInfo* ssl_info) override { |
| 164 NOTIMPLEMENTED(); | 165 NOTIMPLEMENTED(); |
| 165 return false; | 166 return false; |
| 166 } | 167 } |
| 167 | 168 |
| 168 private: | 169 private: |
| 169 MuxChannel* channel_; | 170 MuxChannel* channel_; |
| 170 | 171 |
| 172 int base_channel_error_ = net::OK; |
| 173 |
| 171 net::CompletionCallback read_callback_; | 174 net::CompletionCallback read_callback_; |
| 172 scoped_refptr<net::IOBuffer> read_buffer_; | 175 scoped_refptr<net::IOBuffer> read_buffer_; |
| 173 int read_buffer_size_; | 176 int read_buffer_size_; |
| 174 | 177 |
| 175 bool write_pending_; | 178 bool write_pending_; |
| 176 int write_result_; | 179 int write_result_; |
| 177 net::CompletionCallback write_callback_; | 180 net::CompletionCallback write_callback_; |
| 178 | 181 |
| 179 net::BoundNetLog net_log_; | 182 net::BoundNetLog net_log_; |
| 180 | 183 |
| (...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 213 DCHECK_EQ(packet->channel_id(), receive_id_); | 216 DCHECK_EQ(packet->channel_id(), receive_id_); |
| 214 if (packet->data().size() > 0) { | 217 if (packet->data().size() > 0) { |
| 215 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); | 218 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); |
| 216 if (socket_) { | 219 if (socket_) { |
| 217 // Notify the socket that we have more data. | 220 // Notify the socket that we have more data. |
| 218 socket_->OnPacketReceived(); | 221 socket_->OnPacketReceived(); |
| 219 } | 222 } |
| 220 } | 223 } |
| 221 } | 224 } |
| 222 | 225 |
| 223 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { | 226 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) { |
| 224 if (socket_) | 227 if (socket_) |
| 225 socket_->OnWriteFailed(); | 228 socket_->OnBaseChannelError(error); |
| 226 } | 229 } |
| 227 | 230 |
| 228 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { | 231 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { |
| 229 DCHECK(socket_); | 232 DCHECK(socket_); |
| 230 socket_ = nullptr; | 233 socket_ = nullptr; |
| 231 } | 234 } |
| 232 | 235 |
| 233 bool ChannelMultiplexer::MuxChannel::DoWrite( | 236 bool ChannelMultiplexer::MuxChannel::DoWrite( |
| 234 scoped_ptr<MultiplexPacket> packet, | 237 scoped_ptr<MultiplexPacket> packet, |
| 235 const base::Closure& done_task) { | 238 const base::Closure& done_task) { |
| (...skipping 33 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 269 ChannelMultiplexer::MuxSocket::~MuxSocket() { | 272 ChannelMultiplexer::MuxSocket::~MuxSocket() { |
| 270 channel_->OnSocketDestroyed(); | 273 channel_->OnSocketDestroyed(); |
| 271 } | 274 } |
| 272 | 275 |
| 273 int ChannelMultiplexer::MuxSocket::Read( | 276 int ChannelMultiplexer::MuxSocket::Read( |
| 274 net::IOBuffer* buffer, int buffer_len, | 277 net::IOBuffer* buffer, int buffer_len, |
| 275 const net::CompletionCallback& callback) { | 278 const net::CompletionCallback& callback) { |
| 276 DCHECK(CalledOnValidThread()); | 279 DCHECK(CalledOnValidThread()); |
| 277 DCHECK(read_callback_.is_null()); | 280 DCHECK(read_callback_.is_null()); |
| 278 | 281 |
| 282 if (base_channel_error_ != net::OK) |
| 283 return base_channel_error_; |
| 284 |
| 279 int result = channel_->DoRead(buffer, buffer_len); | 285 int result = channel_->DoRead(buffer, buffer_len); |
| 280 if (result == 0) { | 286 if (result == 0) { |
| 281 read_buffer_ = buffer; | 287 read_buffer_ = buffer; |
| 282 read_buffer_size_ = buffer_len; | 288 read_buffer_size_ = buffer_len; |
| 283 read_callback_ = callback; | 289 read_callback_ = callback; |
| 284 return net::ERR_IO_PENDING; | 290 return net::ERR_IO_PENDING; |
| 285 } | 291 } |
| 286 return result; | 292 return result; |
| 287 } | 293 } |
| 288 | 294 |
| 289 int ChannelMultiplexer::MuxSocket::Write( | 295 int ChannelMultiplexer::MuxSocket::Write( |
| 290 net::IOBuffer* buffer, int buffer_len, | 296 net::IOBuffer* buffer, int buffer_len, |
| 291 const net::CompletionCallback& callback) { | 297 const net::CompletionCallback& callback) { |
| 292 DCHECK(CalledOnValidThread()); | 298 DCHECK(CalledOnValidThread()); |
| 299 DCHECK(write_callback_.is_null()); |
| 300 |
| 301 if (base_channel_error_ != net::OK) |
| 302 return base_channel_error_; |
| 293 | 303 |
| 294 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); | 304 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); |
| 295 size_t size = std::min(kMaxPacketSize, buffer_len); | 305 size_t size = std::min(kMaxPacketSize, buffer_len); |
| 296 packet->mutable_data()->assign(buffer->data(), size); | 306 packet->mutable_data()->assign(buffer->data(), size); |
| 297 | 307 |
| 298 write_pending_ = true; | 308 write_pending_ = true; |
| 299 bool result = channel_->DoWrite(packet.Pass(), base::Bind( | 309 bool result = channel_->DoWrite(packet.Pass(), base::Bind( |
| 300 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); | 310 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); |
| 301 | 311 |
| 302 if (!result) { | 312 if (!result) { |
| 303 // Cannot complete the write, e.g. if the connection has been terminated. | 313 // Cannot complete the write, e.g. if the connection has been terminated. |
| 304 return net::ERR_FAILED; | 314 return net::ERR_FAILED; |
| 305 } | 315 } |
| 306 | 316 |
| 307 // OnWriteComplete() might be called above synchronously. | 317 // OnWriteComplete() might be called above synchronously. |
| 308 if (write_pending_) { | 318 if (write_pending_) { |
| 309 DCHECK(write_callback_.is_null()); | 319 DCHECK(write_callback_.is_null()); |
| 310 write_callback_ = callback; | 320 write_callback_ = callback; |
| 311 write_result_ = size; | 321 write_result_ = size; |
| 312 return net::ERR_IO_PENDING; | 322 return net::ERR_IO_PENDING; |
| 313 } | 323 } |
| 314 | 324 |
| 315 return size; | 325 return size; |
| 316 } | 326 } |
| 317 | 327 |
| 318 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { | 328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { |
| 319 write_pending_ = false; | 329 write_pending_ = false; |
| 320 if (!write_callback_.is_null()) { | 330 if (!write_callback_.is_null()) |
| 321 net::CompletionCallback cb; | 331 base::ResetAndReturn(&write_callback_).Run(write_result_); |
| 322 std::swap(cb, write_callback_); | 332 |
| 323 cb.Run(write_result_); | |
| 324 } | |
| 325 } | 333 } |
| 326 | 334 |
| 327 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { | 335 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) { |
| 328 if (!write_callback_.is_null()) { | 336 base_channel_error_ = error; |
| 329 net::CompletionCallback cb; | 337 |
| 330 std::swap(cb, write_callback_); | 338 // Here only one of the read and write callbacks is called if both of them are |
| 331 cb.Run(net::ERR_FAILED); | 339 // pending. Ideally both of them should be called in that case, but that would |
| 340 // require the second one to be called asynchronously which would complicate |
| 341 // this code. Channels handle read and write errors the same way (see |
| 342 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the |
| 343 // callbacks is enough. |
| 344 |
| 345 if (!read_callback_.is_null()) { |
| 346 base::ResetAndReturn(&read_callback_).Run(error); |
| 347 return; |
| 332 } | 348 } |
| 349 |
| 350 if (!write_callback_.is_null()) |
| 351 base::ResetAndReturn(&write_callback_).Run(error); |
| 333 } | 352 } |
| 334 | 353 |
| 335 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { | 354 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { |
| 336 if (!read_callback_.is_null()) { | 355 if (!read_callback_.is_null()) { |
| 337 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); | 356 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); |
| 338 read_buffer_ = nullptr; | 357 read_buffer_ = nullptr; |
| 339 DCHECK_GT(result, 0); | 358 DCHECK_GT(result, 0); |
| 340 net::CompletionCallback cb; | 359 base::ResetAndReturn(&read_callback_).Run(result); |
| 341 std::swap(cb, read_callback_); | |
| 342 cb.Run(result); | |
| 343 } | 360 } |
| 344 } | 361 } |
| 345 | 362 |
| 346 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, | 363 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, |
| 347 const std::string& base_channel_name) | 364 const std::string& base_channel_name) |
| 348 : base_channel_factory_(factory), | 365 : base_channel_factory_(factory), |
| 349 base_channel_name_(base_channel_name), | 366 base_channel_name_(base_channel_name), |
| 350 next_channel_id_(0), | 367 next_channel_id_(0), |
| 351 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket, | 368 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket, |
| 352 base::Unretained(this)), | 369 base::Unretained(this)), |
| (...skipping 43 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 396 } | 413 } |
| 397 } | 414 } |
| 398 | 415 |
| 399 void ChannelMultiplexer::OnBaseChannelReady( | 416 void ChannelMultiplexer::OnBaseChannelReady( |
| 400 scoped_ptr<net::StreamSocket> socket) { | 417 scoped_ptr<net::StreamSocket> socket) { |
| 401 base_channel_factory_ = nullptr; | 418 base_channel_factory_ = nullptr; |
| 402 base_channel_ = socket.Pass(); | 419 base_channel_ = socket.Pass(); |
| 403 | 420 |
| 404 if (base_channel_.get()) { | 421 if (base_channel_.get()) { |
| 405 // Initialize reader and writer. | 422 // Initialize reader and writer. |
| 406 reader_.StartReading(base_channel_.get()); | 423 reader_.StartReading(base_channel_.get(), |
| 424 base::Bind(&ChannelMultiplexer::OnBaseChannelError, |
| 425 base::Unretained(this))); |
| 407 writer_.Init(base_channel_.get(), | 426 writer_.Init(base_channel_.get(), |
| 408 base::Bind(&ChannelMultiplexer::OnWriteFailed, | 427 base::Bind(&ChannelMultiplexer::OnBaseChannelError, |
| 409 base::Unretained(this))); | 428 base::Unretained(this))); |
| 410 } | 429 } |
| 411 | 430 |
| 412 DoCreatePendingChannels(); | 431 DoCreatePendingChannels(); |
| 413 } | 432 } |
| 414 | 433 |
| 415 void ChannelMultiplexer::DoCreatePendingChannels() { | 434 void ChannelMultiplexer::DoCreatePendingChannels() { |
| 416 if (pending_channels_.empty()) | 435 if (pending_channels_.empty()) |
| 417 return; | 436 return; |
| 418 | 437 |
| (...skipping 21 matching lines...) Expand all Loading... |
| 440 return it->second; | 459 return it->second; |
| 441 | 460 |
| 442 // Create a new channel if we haven't found existing one. | 461 // Create a new channel if we haven't found existing one. |
| 443 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | 462 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); |
| 444 ++next_channel_id_; | 463 ++next_channel_id_; |
| 445 channels_[channel->name()] = channel; | 464 channels_[channel->name()] = channel; |
| 446 return channel; | 465 return channel; |
| 447 } | 466 } |
| 448 | 467 |
| 449 | 468 |
| 450 void ChannelMultiplexer::OnWriteFailed(int error) { | 469 void ChannelMultiplexer::OnBaseChannelError(int error) { |
| 451 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | 470 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); |
| 452 it != channels_.end(); ++it) { | 471 it != channels_.end(); ++it) { |
| 453 base::ThreadTaskRunnerHandle::Get()->PostTask( | 472 base::ThreadTaskRunnerHandle::Get()->PostTask( |
| 454 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, | 473 FROM_HERE, |
| 455 weak_factory_.GetWeakPtr(), it->second->name())); | 474 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError, |
| 475 weak_factory_.GetWeakPtr(), it->second->name(), error)); |
| 456 } | 476 } |
| 457 } | 477 } |
| 458 | 478 |
| 459 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { | 479 void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name, |
| 480 int error) { |
| 460 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | 481 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
| 461 if (it != channels_.end()) { | 482 if (it != channels_.end()) |
| 462 it->second->OnWriteFailed(); | 483 it->second->OnBaseChannelError(error); |
| 463 } | |
| 464 } | 484 } |
| 465 | 485 |
| 466 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | 486 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
| 467 const base::Closure& done_task) { | 487 const base::Closure& done_task) { |
| 468 DCHECK(packet->has_channel_id()); | 488 DCHECK(packet->has_channel_id()); |
| 469 if (!packet->has_channel_id()) { | 489 if (!packet->has_channel_id()) { |
| 470 LOG(ERROR) << "Received packet without channel_id."; | 490 LOG(ERROR) << "Received packet without channel_id."; |
| 471 done_task.Run(); | 491 done_task.Run(); |
| 472 return; | 492 return; |
| 473 } | 493 } |
| (...skipping 20 matching lines...) Expand all Loading... |
| 494 channel->OnIncomingPacket(packet.Pass(), done_task); | 514 channel->OnIncomingPacket(packet.Pass(), done_task); |
| 495 } | 515 } |
| 496 | 516 |
| 497 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, | 517 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, |
| 498 const base::Closure& done_task) { | 518 const base::Closure& done_task) { |
| 499 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); | 519 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); |
| 500 } | 520 } |
| 501 | 521 |
| 502 } // namespace protocol | 522 } // namespace protocol |
| 503 } // namespace remoting | 523 } // namespace remoting |
| OLD | NEW |