Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(111)

Side by Side Diff: remoting/protocol/channel_multiplexer.cc

Issue 1143443003: Fix MessageReader to pass errors to the channel (Closed) Base URL: https://chromium.googlesource.com/chromium/src.git@master
Patch Set: Created 5 years, 7 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be 2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file. 3 // found in the LICENSE file.
4 4
5 #include "remoting/protocol/channel_multiplexer.h" 5 #include "remoting/protocol/channel_multiplexer.h"
6 6
7 #include <string.h> 7 #include <string.h>
8 8
9 #include "base/bind.h" 9 #include "base/bind.h"
10 #include "base/callback.h" 10 #include "base/callback.h"
11 #include "base/callback_helpers.h"
11 #include "base/location.h" 12 #include "base/location.h"
12 #include "base/single_thread_task_runner.h" 13 #include "base/single_thread_task_runner.h"
13 #include "base/stl_util.h" 14 #include "base/stl_util.h"
14 #include "base/thread_task_runner_handle.h" 15 #include "base/thread_task_runner_handle.h"
15 #include "net/base/net_errors.h" 16 #include "net/base/net_errors.h"
16 #include "net/socket/stream_socket.h" 17 #include "net/socket/stream_socket.h"
17 #include "remoting/protocol/message_serialization.h" 18 #include "remoting/protocol/message_serialization.h"
18 19
19 namespace remoting { 20 namespace remoting {
20 namespace protocol { 21 namespace protocol {
(...skipping 51 matching lines...) Expand 10 before | Expand all | Expand 10 after
72 ~MuxChannel(); 73 ~MuxChannel();
73 74
74 const std::string& name() { return name_; } 75 const std::string& name() { return name_; }
75 int receive_id() { return receive_id_; } 76 int receive_id() { return receive_id_; }
76 void set_receive_id(int id) { receive_id_ = id; } 77 void set_receive_id(int id) { receive_id_ = id; }
77 78
78 // Called by ChannelMultiplexer. 79 // Called by ChannelMultiplexer.
79 scoped_ptr<net::StreamSocket> CreateSocket(); 80 scoped_ptr<net::StreamSocket> CreateSocket();
80 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, 81 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
81 const base::Closure& done_task); 82 const base::Closure& done_task);
82 void OnWriteFailed(); 83 void OnBaseChannelError(int error);
83 84
84 // Called by MuxSocket. 85 // Called by MuxSocket.
85 void OnSocketDestroyed(); 86 void OnSocketDestroyed();
86 bool DoWrite(scoped_ptr<MultiplexPacket> packet, 87 bool DoWrite(scoped_ptr<MultiplexPacket> packet,
87 const base::Closure& done_task); 88 const base::Closure& done_task);
88 int DoRead(net::IOBuffer* buffer, int buffer_len); 89 int DoRead(net::IOBuffer* buffer, int buffer_len);
89 90
90 private: 91 private:
91 ChannelMultiplexer* multiplexer_; 92 ChannelMultiplexer* multiplexer_;
92 std::string name_; 93 std::string name_;
93 int send_id_; 94 int send_id_;
94 bool id_sent_; 95 bool id_sent_;
95 int receive_id_; 96 int receive_id_;
96 MuxSocket* socket_; 97 MuxSocket* socket_;
97 std::list<PendingPacket*> pending_packets_; 98 std::list<PendingPacket*> pending_packets_;
98 99
99 DISALLOW_COPY_AND_ASSIGN(MuxChannel); 100 DISALLOW_COPY_AND_ASSIGN(MuxChannel);
100 }; 101 };
101 102
102 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, 103 class ChannelMultiplexer::MuxSocket : public net::StreamSocket,
103 public base::NonThreadSafe, 104 public base::NonThreadSafe,
104 public base::SupportsWeakPtr<MuxSocket> { 105 public base::SupportsWeakPtr<MuxSocket> {
105 public: 106 public:
106 MuxSocket(MuxChannel* channel); 107 MuxSocket(MuxChannel* channel);
107 ~MuxSocket() override; 108 ~MuxSocket() override;
108 109
109 void OnWriteComplete(); 110 void OnWriteComplete();
110 void OnWriteFailed(); 111 void OnBaseChannelError(int error);
111 void OnPacketReceived(); 112 void OnPacketReceived();
112 113
113 // net::StreamSocket interface. 114 // net::StreamSocket interface.
114 int Read(net::IOBuffer* buffer, 115 int Read(net::IOBuffer* buffer,
115 int buffer_len, 116 int buffer_len,
116 const net::CompletionCallback& callback) override; 117 const net::CompletionCallback& callback) override;
117 int Write(net::IOBuffer* buffer, 118 int Write(net::IOBuffer* buffer,
118 int buffer_len, 119 int buffer_len,
119 const net::CompletionCallback& callback) override; 120 const net::CompletionCallback& callback) override;
120 121
(...skipping 40 matching lines...) Expand 10 before | Expand all | Expand 10 after
161 return net::kProtoUnknown; 162 return net::kProtoUnknown;
162 } 163 }
163 bool GetSSLInfo(net::SSLInfo* ssl_info) override { 164 bool GetSSLInfo(net::SSLInfo* ssl_info) override {
164 NOTIMPLEMENTED(); 165 NOTIMPLEMENTED();
165 return false; 166 return false;
166 } 167 }
167 168
168 private: 169 private:
169 MuxChannel* channel_; 170 MuxChannel* channel_;
170 171
172 int base_channel_error_ = net::OK;
173
171 net::CompletionCallback read_callback_; 174 net::CompletionCallback read_callback_;
172 scoped_refptr<net::IOBuffer> read_buffer_; 175 scoped_refptr<net::IOBuffer> read_buffer_;
173 int read_buffer_size_; 176 int read_buffer_size_;
174 177
175 bool write_pending_; 178 bool write_pending_;
176 int write_result_; 179 int write_result_;
177 net::CompletionCallback write_callback_; 180 net::CompletionCallback write_callback_;
178 181
179 net::BoundNetLog net_log_; 182 net::BoundNetLog net_log_;
180 183
(...skipping 32 matching lines...) Expand 10 before | Expand all | Expand 10 after
213 DCHECK_EQ(packet->channel_id(), receive_id_); 216 DCHECK_EQ(packet->channel_id(), receive_id_);
214 if (packet->data().size() > 0) { 217 if (packet->data().size() > 0) {
215 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); 218 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task));
216 if (socket_) { 219 if (socket_) {
217 // Notify the socket that we have more data. 220 // Notify the socket that we have more data.
218 socket_->OnPacketReceived(); 221 socket_->OnPacketReceived();
219 } 222 }
220 } 223 }
221 } 224 }
222 225
223 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { 226 void ChannelMultiplexer::MuxChannel::OnBaseChannelError(int error) {
224 if (socket_) 227 if (socket_)
225 socket_->OnWriteFailed(); 228 socket_->OnBaseChannelError(error);
226 } 229 }
227 230
228 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { 231 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() {
229 DCHECK(socket_); 232 DCHECK(socket_);
230 socket_ = nullptr; 233 socket_ = nullptr;
231 } 234 }
232 235
233 bool ChannelMultiplexer::MuxChannel::DoWrite( 236 bool ChannelMultiplexer::MuxChannel::DoWrite(
234 scoped_ptr<MultiplexPacket> packet, 237 scoped_ptr<MultiplexPacket> packet,
235 const base::Closure& done_task) { 238 const base::Closure& done_task) {
(...skipping 33 matching lines...) Expand 10 before | Expand all | Expand 10 after
269 ChannelMultiplexer::MuxSocket::~MuxSocket() { 272 ChannelMultiplexer::MuxSocket::~MuxSocket() {
270 channel_->OnSocketDestroyed(); 273 channel_->OnSocketDestroyed();
271 } 274 }
272 275
273 int ChannelMultiplexer::MuxSocket::Read( 276 int ChannelMultiplexer::MuxSocket::Read(
274 net::IOBuffer* buffer, int buffer_len, 277 net::IOBuffer* buffer, int buffer_len,
275 const net::CompletionCallback& callback) { 278 const net::CompletionCallback& callback) {
276 DCHECK(CalledOnValidThread()); 279 DCHECK(CalledOnValidThread());
277 DCHECK(read_callback_.is_null()); 280 DCHECK(read_callback_.is_null());
278 281
282 if (base_channel_error_ != net::OK)
283 return base_channel_error_;
284
279 int result = channel_->DoRead(buffer, buffer_len); 285 int result = channel_->DoRead(buffer, buffer_len);
280 if (result == 0) { 286 if (result == 0) {
281 read_buffer_ = buffer; 287 read_buffer_ = buffer;
282 read_buffer_size_ = buffer_len; 288 read_buffer_size_ = buffer_len;
283 read_callback_ = callback; 289 read_callback_ = callback;
284 return net::ERR_IO_PENDING; 290 return net::ERR_IO_PENDING;
285 } 291 }
286 return result; 292 return result;
287 } 293 }
288 294
289 int ChannelMultiplexer::MuxSocket::Write( 295 int ChannelMultiplexer::MuxSocket::Write(
290 net::IOBuffer* buffer, int buffer_len, 296 net::IOBuffer* buffer, int buffer_len,
291 const net::CompletionCallback& callback) { 297 const net::CompletionCallback& callback) {
292 DCHECK(CalledOnValidThread()); 298 DCHECK(CalledOnValidThread());
299 DCHECK(write_callback_.is_null());
300
301 if (base_channel_error_ != net::OK)
302 return base_channel_error_;
293 303
294 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); 304 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket());
295 size_t size = std::min(kMaxPacketSize, buffer_len); 305 size_t size = std::min(kMaxPacketSize, buffer_len);
296 packet->mutable_data()->assign(buffer->data(), size); 306 packet->mutable_data()->assign(buffer->data(), size);
297 307
298 write_pending_ = true; 308 write_pending_ = true;
299 bool result = channel_->DoWrite(packet.Pass(), base::Bind( 309 bool result = channel_->DoWrite(packet.Pass(), base::Bind(
300 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); 310 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr()));
301 311
302 if (!result) { 312 if (!result) {
303 // Cannot complete the write, e.g. if the connection has been terminated. 313 // Cannot complete the write, e.g. if the connection has been terminated.
304 return net::ERR_FAILED; 314 return net::ERR_FAILED;
305 } 315 }
306 316
307 // OnWriteComplete() might be called above synchronously. 317 // OnWriteComplete() might be called above synchronously.
308 if (write_pending_) { 318 if (write_pending_) {
309 DCHECK(write_callback_.is_null()); 319 DCHECK(write_callback_.is_null());
310 write_callback_ = callback; 320 write_callback_ = callback;
311 write_result_ = size; 321 write_result_ = size;
312 return net::ERR_IO_PENDING; 322 return net::ERR_IO_PENDING;
313 } 323 }
314 324
315 return size; 325 return size;
316 } 326 }
317 327
318 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { 328 void ChannelMultiplexer::MuxSocket::OnWriteComplete() {
319 write_pending_ = false; 329 write_pending_ = false;
320 if (!write_callback_.is_null()) { 330 if (!write_callback_.is_null())
321 net::CompletionCallback cb; 331 base::ResetAndReturn(&write_callback_).Run(write_result_);
322 std::swap(cb, write_callback_); 332
323 cb.Run(write_result_);
324 }
325 } 333 }
326 334
327 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { 335 void ChannelMultiplexer::MuxSocket::OnBaseChannelError(int error) {
328 if (!write_callback_.is_null()) { 336 base_channel_error_ = error;
329 net::CompletionCallback cb; 337
330 std::swap(cb, write_callback_); 338 // Here only one of the read and write callbacks is called if both of them are
331 cb.Run(net::ERR_FAILED); 339 // pending. Ideally both of them should be called in that case, but that would
340 // require the second one to be called asynchronously which would complicate
341 // this code. Channels handle read and write errors the same way (see
342 // ChannelDispatcherBase::OnReadWriteFailed) so calling only one of the
343 // callbacks is enough.
Jamie 2015/05/13 18:25:17 But you're only resetting one of the callbacks in
Sergey Ulanov 2015/05/14 00:15:21 It's better to leave write_callback_ not null in t
344
345 if (!read_callback_.is_null()) {
346 base::ResetAndReturn(&read_callback_).Run(error);
347 return;
332 } 348 }
349
350 if (!write_callback_.is_null())
351 base::ResetAndReturn(&write_callback_).Run(error);
333 } 352 }
334 353
335 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { 354 void ChannelMultiplexer::MuxSocket::OnPacketReceived() {
336 if (!read_callback_.is_null()) { 355 if (!read_callback_.is_null()) {
337 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_); 356 int result = channel_->DoRead(read_buffer_.get(), read_buffer_size_);
338 read_buffer_ = nullptr; 357 read_buffer_ = nullptr;
339 DCHECK_GT(result, 0); 358 DCHECK_GT(result, 0);
340 net::CompletionCallback cb; 359 base::ResetAndReturn(&read_callback_).Run(result);
341 std::swap(cb, read_callback_);
342 cb.Run(result);
343 } 360 }
344 } 361 }
345 362
346 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory, 363 ChannelMultiplexer::ChannelMultiplexer(StreamChannelFactory* factory,
347 const std::string& base_channel_name) 364 const std::string& base_channel_name)
348 : base_channel_factory_(factory), 365 : base_channel_factory_(factory),
349 base_channel_name_(base_channel_name), 366 base_channel_name_(base_channel_name),
350 next_channel_id_(0), 367 next_channel_id_(0),
351 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket, 368 parser_(base::Bind(&ChannelMultiplexer::OnIncomingPacket,
352 base::Unretained(this)), 369 base::Unretained(this)),
(...skipping 43 matching lines...) Expand 10 before | Expand all | Expand 10 after
396 } 413 }
397 } 414 }
398 415
399 void ChannelMultiplexer::OnBaseChannelReady( 416 void ChannelMultiplexer::OnBaseChannelReady(
400 scoped_ptr<net::StreamSocket> socket) { 417 scoped_ptr<net::StreamSocket> socket) {
401 base_channel_factory_ = nullptr; 418 base_channel_factory_ = nullptr;
402 base_channel_ = socket.Pass(); 419 base_channel_ = socket.Pass();
403 420
404 if (base_channel_.get()) { 421 if (base_channel_.get()) {
405 // Initialize reader and writer. 422 // Initialize reader and writer.
406 reader_.StartReading(base_channel_.get()); 423 reader_.StartReading(base_channel_.get(),
424 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
425 base::Unretained(this)));
407 writer_.Init(base_channel_.get(), 426 writer_.Init(base_channel_.get(),
408 base::Bind(&ChannelMultiplexer::OnWriteFailed, 427 base::Bind(&ChannelMultiplexer::OnBaseChannelError,
409 base::Unretained(this))); 428 base::Unretained(this)));
410 } 429 }
411 430
412 DoCreatePendingChannels(); 431 DoCreatePendingChannels();
413 } 432 }
414 433
415 void ChannelMultiplexer::DoCreatePendingChannels() { 434 void ChannelMultiplexer::DoCreatePendingChannels() {
416 if (pending_channels_.empty()) 435 if (pending_channels_.empty())
417 return; 436 return;
418 437
(...skipping 21 matching lines...) Expand all
440 return it->second; 459 return it->second;
441 460
442 // Create a new channel if we haven't found existing one. 461 // Create a new channel if we haven't found existing one.
443 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); 462 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_);
444 ++next_channel_id_; 463 ++next_channel_id_;
445 channels_[channel->name()] = channel; 464 channels_[channel->name()] = channel;
446 return channel; 465 return channel;
447 } 466 }
448 467
449 468
450 void ChannelMultiplexer::OnWriteFailed(int error) { 469 void ChannelMultiplexer::OnBaseChannelError(int error) {
451 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); 470 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin();
452 it != channels_.end(); ++it) { 471 it != channels_.end(); ++it) {
453 base::ThreadTaskRunnerHandle::Get()->PostTask( 472 base::ThreadTaskRunnerHandle::Get()->PostTask(
454 FROM_HERE, base::Bind(&ChannelMultiplexer::NotifyWriteFailed, 473 FROM_HERE,
455 weak_factory_.GetWeakPtr(), it->second->name())); 474 base::Bind(&ChannelMultiplexer::NotifyBaseChannelError,
475 weak_factory_.GetWeakPtr(), it->second->name(), error));
456 } 476 }
457 } 477 }
458 478
459 void ChannelMultiplexer::NotifyWriteFailed(const std::string& name) { 479 void ChannelMultiplexer::NotifyBaseChannelError(const std::string& name,
480 int error) {
460 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); 481 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name);
461 if (it != channels_.end()) { 482 if (it != channels_.end())
462 it->second->OnWriteFailed(); 483 it->second->OnBaseChannelError(error);
463 }
464 } 484 }
465 485
466 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, 486 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet,
467 const base::Closure& done_task) { 487 const base::Closure& done_task) {
468 DCHECK(packet->has_channel_id()); 488 DCHECK(packet->has_channel_id());
469 if (!packet->has_channel_id()) { 489 if (!packet->has_channel_id()) {
470 LOG(ERROR) << "Received packet without channel_id."; 490 LOG(ERROR) << "Received packet without channel_id.";
471 done_task.Run(); 491 done_task.Run();
472 return; 492 return;
473 } 493 }
(...skipping 20 matching lines...) Expand all
494 channel->OnIncomingPacket(packet.Pass(), done_task); 514 channel->OnIncomingPacket(packet.Pass(), done_task);
495 } 515 }
496 516
497 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, 517 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet,
498 const base::Closure& done_task) { 518 const base::Closure& done_task) {
499 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); 519 return writer_.Write(SerializeAndFrameMessage(*packet), done_task);
500 } 520 }
501 521
502 } // namespace protocol 522 } // namespace protocol
503 } // namespace remoting 523 } // namespace remoting
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698