OLD | NEW |
---|---|
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "remoting/protocol/channel_multiplexer.h" | 5 #include "remoting/protocol/channel_multiplexer.h" |
6 | 6 |
7 #include <string.h> | 7 #include <string.h> |
8 | 8 |
9 #include "base/bind.h" | 9 #include "base/bind.h" |
10 #include "base/callback.h" | 10 #include "base/callback.h" |
11 #include "base/callback_helpers.h" | |
11 #include "base/location.h" | 12 #include "base/location.h" |
12 #include "base/single_thread_task_runner.h" | 13 #include "base/single_thread_task_runner.h" |
13 #include "base/stl_util.h" | 14 #include "base/stl_util.h" |
14 #include "base/thread_task_runner_handle.h" | 15 #include "base/thread_task_runner_handle.h" |
15 #include "net/base/net_errors.h" | 16 #include "net/base/net_errors.h" |
16 #include "net/socket/stream_socket.h" | 17 #include "net/socket/stream_socket.h" |
17 #include "remoting/protocol/message_serialization.h" | 18 #include "remoting/protocol/message_serialization.h" |
18 | 19 |
19 namespace remoting { | 20 namespace remoting { |
20 namespace protocol { | 21 namespace protocol { |
(...skipping 51 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
72 ~MuxChannel(); | 73 ~MuxChannel(); |
73 | 74 |
74 const std::string& name() { return name_; } | 75 const std::string& name() { return name_; } |
75 int receive_id() { return receive_id_; } | 76 int receive_id() { return receive_id_; } |
76 void set_receive_id(int id) { receive_id_ = id; } | 77 void set_receive_id(int id) { receive_id_ = id; } |
77 | 78 |
78 // Called by ChannelMultiplexer. | 79 // Called by ChannelMultiplexer. |
79 scoped_ptr<net::StreamSocket> CreateSocket(); | 80 scoped_ptr<net::StreamSocket> CreateSocket(); |
80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | 81 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
81 const base::Closure& done_task); | 82 const base::Closure& done_task); |
82 void OnWriteFailed(); | 83 void OnBaseChannelError(int error); |
83 | 84 |
84 // Called by MuxSocket. | 85 // Called by MuxSocket. |
85 void OnSocketDestroyed(); | 86 void OnSocketDestroyed(); |
86 bool DoWrite(scoped_ptr<MultiplexPacket> packet, | 87 bool DoWrite(scoped_ptr<MultiplexPacket> packet, |
87 const base::Closure& done_task); | 88 const base::Closure& done_task); |
88 int DoRead(net::IOBuffer* buffer, int buffer_len); | 89 int DoRead(net::IOBuffer* buffer, int buffer_len); |
89 | 90 |
90 private: | 91 private: |
91 ChannelMultiplexer* multiplexer_; | 92 ChannelMultiplexer* multiplexer_; |
92 std::string name_; | 93 std::string name_; |
93 int send_id_; | 94 int send_id_; |
94 bool id_sent_; | 95 bool id_sent_; |
95 int receive_id_; | 96 int receive_id_; |
96 MuxSocket* socket_; | 97 MuxSocket* socket_; |
97 std::list<PendingPacket*> pending_packets_; | 98 std::list<PendingPacket*> pending_packets_; |
98 | 99 |
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel); | 100 DISALLOW_COPY_AND_ASSIGN(MuxChannel); |
100 }; | 101 }; |
101 | 102 |
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, | 103 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, |
103 public base::NonThreadSafe, | 104 public base::NonThreadSafe, |
104 public base::SupportsWeakPtr<MuxSocket> { | 105 public base::SupportsWeakPtr<MuxSocket> { |
105 public: | 106 public: |
106 MuxSocket(MuxChannel* channel); | 107 MuxSocket(MuxChannel* channel); |
107 ~MuxSocket() override; | 108 ~MuxSocket() override; |
108 | 109 |
109 void OnWriteComplete(); | 110 void OnWriteComplete(); |
110 void OnWriteFailed(); | 111 void OnBaseChannelError(int error); |
111 void OnPacketReceived(); | 112 void OnPacketReceived(); |
112 | 113 |
113 // net::StreamSocket interface. | 114 // net::StreamSocket interface. |
114 int Read(net::IOBuffer* buffer, | 115 int Read(net::IOBuffer* buffer, |
115 int buffer_len, | 116 int buffer_len, |
116 const net::CompletionCallback& callback) override; | 117 const net::CompletionCallback& callback) override; |
117 int Write(net::IOBuffer* buffer, | 118 int Write(net::IOBuffer* buffer, |
118 int buffer_len, | 119 int buffer_len, |
119 const net::CompletionCallback& callback) override; | 120 const net::CompletionCallback& callback) override; |
120 | 121 |
(...skipping 40 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
161 return net::kProtoUnknown; | 162 return net::kProtoUnknown; |
162 } | 163 } |
163 bool GetSSLInfo(net::SSLInfo* ssl_info) override { | 164 bool GetSSLInfo(net::SSLInfo* ssl_info) override { |
164 NOTIMPLEMENTED(); | 165 NOTIMPLEMENTED(); |
165 return false; | 166 return false; |
166 } | 167 } |
167 | 168 |
168 private: | 169 private: |
169 MuxChannel* channel_; | 170 MuxChannel* channel_; |
170 | 171 |
172 int base_channel_error_ = net::OK; | |
173 | |
171 net::CompletionCallback read_callback_; | 174 net::CompletionCallback read_callback_; |
172 scoped_refptr<net::IOBuffer> read_buffer_; | 175 scoped_refptr<net::IOBuffer> read_buffer_; |
173 int read_buffer_size_; | 176 int read_buffer_size_; |
174 | 177 |
175 bool write_pending_; | 178 bool write_pending_; |
176 int write_result_; | 179 int write_result_; |
177 net::CompletionCallback write_callback_; | 180 net::CompletionCallback write_callback_; |
178 | 181 |
179 net::BoundNetLog net_log_; | 182 net::BoundNetLog net_log_; |
180 | 183 |
(...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
213 DCHECK_EQ(packet->channel_id(), receive_id_); | 216 DCHECK_EQ(packet->channel_id(), receive_id_); |
214 if (packet->data().size() > 0) { | 217 if (packet->data().size() > 0) { |
215 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); | 218 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); |
216 if (socket_) { | 219 if (socket_) { |
217 // Notify the socket that we have more data. | 220 // Notify the socket that we have more data. |
218 socket_->OnPacketReceived(); | 221 socket_->OnPacketReceived(); |
219 } | 222 } |
220 } | 223 } |
221 } | 224 } |
222 | 225 |
223 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { | 226 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) { |
224 if (socket_) | 227 if (socket_) |
225 socket_->OnWriteFailed(); | 228 socket_->OnBaseChannelError(error); |
226 } | 229 } |
227 | 230 |
228 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { | 231 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { |
229 DCHECK(socket_); | 232 DCHECK(socket_); |
230 socket_ = nullptr; | 233 socket_ = nullptr; |
231 } | 234 } |
232 | 235 |
233 bool ChannelMultiplexer::MuxChannel::DoWrite( | 236 bool ChannelMultiplexer::MuxChannel::DoWrite( |
234 scoped_ptr<MultiplexPacket> packet, | 237 scoped_ptr<MultiplexPacket> packet, |
235 const base::Closure& done_task) { | 238 const base::Closure& done_task) { |
(...skipping 33 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
269 ChannelMultiplexer::MuxSocket::~MuxSocket() { | 272 ChannelMultiplexer::MuxSocket::~MuxSocket() { |
270 channel_->OnSocketDestroyed(); | 273 channel_->OnSocketDestroyed(); |
271 } | 274 } |
272 | 275 |
273 int ChannelMultiplexer::MuxSocket::Read( | 276 int ChannelMultiplexer::MuxSocket::Read( |
274 net::IOBuffer* buffer, int buffer_len, | 277 net::IOBuffer* buffer, int buffer_len, |
275 const net::CompletionCallback& callback) { | 278 const net::CompletionCallback& callback) { |
276 DCHECK(CalledOnValidThread()); | 279 DCHECK(CalledOnValidThread()); |
277 DCHECK(read_callback_.is_null()); | 280 DCHECK(read_callback_.is_null()); |
278 | 281 |
282 if (base_channel_error_ != net::OK) | |
283 return base_channel_error_; | |
284 | |
279 int result = channel_->DoRead(buffer, buffer_len); | 285 int result = channel_->DoRead(buffer, buffer_len); |
280 if (result == 0) { | 286 if (result == 0) { |
281 read_buffer_ = buffer; | 287 read_buffer_ = buffer; |
282 read_buffer_size_ = buffer_len; | 288 read_buffer_size_ = buffer_len; |
283 read_callback_ = callback; | 289 read_callback_ = callback; |
284 return net::ERR_IO_PENDING; | 290 return net::ERR_IO_PENDING; |
285 } | 291 } |
286 return result; | 292 return result; |
287 } | 293 } |
288 | 294 |
289 int ChannelMultiplexer::MuxSocket::Write( | 295 int ChannelMultiplexer::MuxSocket::Write( |
290 net::IOBuffer* buffer, int buffer_len, | 296 net::IOBuffer* buffer, int buffer_len, |
291 const net::CompletionCallback& callback) { | 297 const net::CompletionCallback& callback) { |
292 DCHECK(CalledOnValidThread()); | 298 DCHECK(CalledOnValidThread()); |
299 DCHECK(write_callback_.is_null()); | |
300 | |
301 if (base_channel_error_ != net::OK) | |
302 return base_channel_error_; | |
293 | 303 |
294 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); | 304 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); |
295 size_t size = std::min(kMaxPacketSize, buffer_len); | 305 size_t size = std::min(kMaxPacketSize, buffer_len); |
296 packet->mutable_data()->assign(buffer->data(), size); | 306 packet->mutable_data()->assign(buffer->data(), size); |
297 | 307 |
298 write_pending_ = true; | 308 write_pending_ = true; |
299 bool result = channel_->DoWrite(packet.Pass(), base::Bind( | 309 bool result = channel_->DoWrite(packet.Pass(), base::Bind( |
300 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); | 310 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); |
301 | 311 |
302 if (!result) { | 312 if (!result) { |
303 // Cannot complete the write, e.g. if the connection has been terminated. | 313 // Cannot complete the write, e.g. if the connection has been terminated. |
304 return net::ERR_FAILED; | 314 return net::ERR_FAILED; |
305 } | 315 } |
306 | 316 |
307 // OnWriteComplete() might be called above synchronously. | 317 // OnWriteComplete() might be called above synchronously. |
308 if (write_pending_) { | 318 if (write_pending_) { |
309 DCHECK(write_callback_.is_null()); | 319 DCHECK(write_callback_.is_null()); |
310 write_callback_ = callback; | 320 write_callback_ = callback; |
311 write_result_ = size; | 321 write_result_ = size; |
312 return net::ERR_IO_PENDING; | 322 return net::ERR_IO_PENDING; |
313 } | 323 } |
314 | 324 |
315 return size; | 325 return size; |
316 } | 326 } |
317 | 327 |
318 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { | 328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { |
319 write_pending_ = false; | 329 write_pending_ = false; |
320 if (!write_callback_.is_null()) { | 330 if (!write_callback_.is_null()) |
321 net::CompletionCallback cb; | 331 base::ResetAndReturn(&write_callback_).Run(write_result_); |
322 std::swap(cb, write_callback_); | 332 |
323 cb.Run(write_result_); | |
324 } | |
325 } | 333 } |
326 | 334 |
327 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { | 335 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) { |
328 if (!write_callback_.is_null()) { | 336 base_channel_error_ = error; |
329 net::CompletionCallback cb; | 337 |
330 std::swap(cb, write_callback_); | 338 // Here only one of the read and write callbacks is called if both of them are |
331 cb.Run(net::ERR_FAILED); | 339 // pending. Ideally both of them should be called in that case, but that would |
340 // require the second one to be called asynchronously which would complicate | |
341 // this code. Channels handle read and write errors the same way (see | |
342 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the | |
343 // callbacks is enough. | |
Jamie
2015/05/13 18:25:17
But you're only resetting one of the callbacks in
Sergey Ulanov
2015/05/14 00:15:21
It's better to leave write_callback_ not null in t
| |
344 | |
345 if (!read_callback_.is_null()) { | |
346 base::ResetAndReturn(&read_callback_).Run(error); | |
347 return; | |
332 } | 348 } |
349 | |
350 if (!write_callback_.is_null()) | |
351 base::ResetAndReturn(&write_callback_).Run(error); | |
333 } | 352 } |
334 | 353 |
335 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { | 354 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { |
336 if (!read_callback_.is_null()) { | 355 if (!read_callback_.is_null()) { |
337 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); | 356 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); |
338 read_buffer_ = nullptr; | 357 read_buffer_ = nullptr; |
339 DCHECK_GT(result, 0); | 358 DCHECK_GT(result, 0); |
340 net::CompletionCallback cb; | 359 base::ResetAndReturn(&read_callback_).Run(result); |
341 std::swap(cb, read_callback_); | |
342 cb.Run(result); | |
343 } | 360 } |
344 } | 361 } |
345 | 362 |
346 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, | 363 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, |
347 const std::string& base_channel_name) | 364 const std::string& base_channel_name) |
348 : base_channel_factory_(factory), | 365 : base_channel_factory_(factory), |
349 base_channel_name_(base_channel_name), | 366 base_channel_name_(base_channel_name), |
350 next_channel_id_(0), | 367 next_channel_id_(0), |
351 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket, | 368 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket, |
352 base::Unretained(this)), | 369 base::Unretained(this)), |
(...skipping 43 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
396 } | 413 } |
397 } | 414 } |
398 | 415 |
399 void ChannelMultiplexer::OnBaseChannelReady( | 416 void ChannelMultiplexer::OnBaseChannelReady( |
400 scoped_ptr<net::StreamSocket> socket) { | 417 scoped_ptr<net::StreamSocket> socket) { |
401 base_channel_factory_ = nullptr; | 418 base_channel_factory_ = nullptr; |
402 base_channel_ = socket.Pass(); | 419 base_channel_ = socket.Pass(); |
403 | 420 |
404 if (base_channel_.get()) { | 421 if (base_channel_.get()) { |
405 // Initialize reader and writer. | 422 // Initialize reader and writer. |
406 reader_.StartReading(base_channel_.get()); | 423 reader_.StartReading(base_channel_.get(), |
424 base::Bind(&ChannelMultiplexer::OnBaseChannelError, | |
425 base::Unretained(this))); | |
407 writer_.Init(base_channel_.get(), | 426 writer_.Init(base_channel_.get(), |
408 base::Bind(&ChannelMultiplexer::OnWriteFailed, | 427 base::Bind(&ChannelMultiplexer::OnBaseChannelError, |
409 base::Unretained(this))); | 428 base::Unretained(this))); |
410 } | 429 } |
411 | 430 |
412 DoCreatePendingChannels(); | 431 DoCreatePendingChannels(); |
413 } | 432 } |
414 | 433 |
415 void ChannelMultiplexer::DoCreatePendingChannels() { | 434 void ChannelMultiplexer::DoCreatePendingChannels() { |
416 if (pending_channels_.empty()) | 435 if (pending_channels_.empty()) |
417 return; | 436 return; |
418 | 437 |
(...skipping 21 matching lines...) Expand all Loading... | |
440 return it->second; | 459 return it->second; |
441 | 460 |
442 // Create a new channel if we haven't found existing one. | 461 // Create a new channel if we haven't found existing one. |
443 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | 462 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); |
444 ++next_channel_id_; | 463 ++next_channel_id_; |
445 channels_[channel->name()] = channel; | 464 channels_[channel->name()] = channel; |
446 return channel; | 465 return channel; |
447 } | 466 } |
448 | 467 |
449 | 468 |
450 void ChannelMultiplexer::OnWriteFailed(int error) { | 469 void ChannelMultiplexer::OnBaseChannelError(int error) { |
451 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | 470 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); |
452 it != channels_.end(); ++it) { | 471 it != channels_.end(); ++it) { |
453 base::ThreadTaskRunnerHandle::Get()->PostTask( | 472 base::ThreadTaskRunnerHandle::Get()->PostTask( |
454 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, | 473 FROM_HERE, |
455 weak_factory_.GetWeakPtr(), it->second->name())); | 474 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError, |
475 weak_factory_.GetWeakPtr(), it->second->name(), error)); | |
456 } | 476 } |
457 } | 477 } |
458 | 478 |
459 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { | 479 void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name, |
480 int error) { | |
460 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | 481 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); |
461 if (it != channels_.end()) { | 482 if (it != channels_.end()) |
462 it->second->OnWriteFailed(); | 483 it->second->OnBaseChannelError(error); |
463 } | |
464 } | 484 } |
465 | 485 |
466 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | 486 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, |
467 const base::Closure& done_task) { | 487 const base::Closure& done_task) { |
468 DCHECK(packet->has_channel_id()); | 488 DCHECK(packet->has_channel_id()); |
469 if (!packet->has_channel_id()) { | 489 if (!packet->has_channel_id()) { |
470 LOG(ERROR) << "Received packet without channel_id."; | 490 LOG(ERROR) << "Received packet without channel_id."; |
471 done_task.Run(); | 491 done_task.Run(); |
472 return; | 492 return; |
473 } | 493 } |
(...skipping 20 matching lines...) Expand all Loading... | |
494 channel->OnIncomingPacket(packet.Pass(), done_task); | 514 channel->OnIncomingPacket(packet.Pass(), done_task); |
495 } | 515 } |
496 | 516 |
497 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, | 517 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, |
498 const base::Closure& done_task) { | 518 const base::Closure& done_task) { |
499 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); | 519 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); |
500 } | 520 } |
501 | 521 |
502 } // namespace protocol | 522 } // namespace protocol |
503 } // namespace remoting | 523 } // namespace remoting |
OLD | NEW |