OLD | NEW |
---|---|
(Empty) | |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "remoting/protocol/channel_multiplexer.h" | |
6 | |
7 #include <string.h> | |
8 | |
9 #include "base/bind.h" | |
10 #include "base/callback.h" | |
11 #include "base/location.h" | |
12 #include "base/stl_util.h" | |
13 #include "net/base/net_errors.h" | |
14 #include "net/socket/stream_socket.h" | |
15 #include "remoting/protocol/util.h" | |
16 | |
17 namespace remoting { | |
18 namespace protocol { | |
19 | |
20 namespace { | |
21 const int kChannelIdUnknown = -1; | |
22 const int kMaxPacketSize = 1024; | |
23 | |
24 class PendingPacket { | |
25 public: | |
26 PendingPacket(scoped_ptr<MultiplexPacket> packet, | |
27 const base::Closure& done_task) | |
28 : packet(packet.Pass()), | |
29 done_task(done_task), | |
30 pos(0U) { | |
31 } | |
32 ~PendingPacket() { | |
33 done_task.Run(); | |
34 } | |
35 | |
36 bool is_empty() { return pos >= packet->data().size(); } | |
37 | |
38 int Read(char* buffer, size_t size) { | |
39 size = std::min(size, packet->data().size() - pos); | |
40 memcpy(buffer, packet->data().data() + pos, size); | |
41 pos += size; | |
42 return size; | |
43 } | |
44 | |
45 private: | |
46 scoped_ptr<MultiplexPacket> packet; | |
47 base::Closure done_task; | |
48 size_t pos; | |
49 | |
50 DISALLOW_COPY_AND_ASSIGN(PendingPacket); | |
51 }; | |
52 | |
53 } // namespace | |
54 | |
55 const char ChannelMultiplexer::kMuxChannelName[] = "mux"; | |
56 | |
57 struct ChannelMultiplexer::PendingChannel { | |
58 PendingChannel(const std::string& name, | |
59 const StreamChannelCallback& callback) | |
60 : name(name), callback(callback) { | |
61 } | |
62 std::string name; | |
63 StreamChannelCallback callback; | |
64 }; | |
65 | |
66 class ChannelMultiplexer::MuxChannel { | |
67 public: | |
68 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, | |
69 int send_id); | |
70 ~MuxChannel(); | |
71 | |
72 const std::string& name() { return name_; } | |
73 int receive_id() { return receive_id_; } | |
74 void set_receive_id(int id) { receive_id_ = id; } | |
75 | |
76 // Called by ChannelMultiplexer. | |
77 scoped_ptr<net::StreamSocket> CreateSocket(); | |
78 void OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | |
79 const base::Closure& done_task); | |
80 void OnWriteFailed(); | |
81 | |
82 // Called by MuxSocket. | |
83 void OnSocketDestroyed(); | |
84 bool DoWrite(scoped_ptr<MultiplexPacket> packet, | |
85 const base::Closure& done_task); | |
86 int DoRead(net::IOBuffer* buffer, int buffer_len); | |
87 | |
88 private: | |
89 ChannelMultiplexer* multiplexer_; | |
90 std::string name_; | |
91 int send_id_; | |
92 bool id_sent_; | |
93 int receive_id_; | |
94 MuxSocket* socket_; | |
95 std::list<PendingPacket*> pending_packets_; | |
96 | |
97 DISALLOW_COPY_AND_ASSIGN(MuxChannel); | |
98 }; | |
99 | |
100 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, | |
101 public base::NonThreadSafe, | |
102 public base::SupportsWeakPtr<MuxSocket> { | |
103 public: | |
104 MuxSocket(MuxChannel* channel); | |
105 ~MuxSocket(); | |
106 | |
107 void OnWriteComplete(); | |
108 void OnWriteFailed(); | |
109 void OnPacketReceived(); | |
110 | |
111 // net::StreamSocket interface. | |
112 virtual int Read(net::IOBuffer* buffer, int buffer_len, | |
113 const net::CompletionCallback& callback) OVERRIDE; | |
114 virtual int Write(net::IOBuffer* buffer, int buffer_len, | |
115 const net::CompletionCallback& callback) OVERRIDE; | |
116 | |
117 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
118 NOTIMPLEMENTED(); | |
119 return false; | |
120 } | |
121 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
122 NOTIMPLEMENTED(); | |
123 return false; | |
124 } | |
125 | |
126 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { | |
127 NOTIMPLEMENTED(); | |
128 return net::ERR_FAILED; | |
129 } | |
130 virtual void Disconnect() OVERRIDE { | |
131 NOTIMPLEMENTED(); | |
132 } | |
133 virtual bool IsConnected() const OVERRIDE { | |
134 NOTIMPLEMENTED(); | |
135 return true; | |
136 } | |
137 virtual bool IsConnectedAndIdle() const OVERRIDE { | |
138 NOTIMPLEMENTED(); | |
139 return false; | |
140 } | |
141 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { | |
142 NOTIMPLEMENTED(); | |
143 return net::ERR_FAILED; | |
144 } | |
145 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { | |
146 NOTIMPLEMENTED(); | |
147 return net::ERR_FAILED; | |
148 } | |
149 virtual const net::BoundNetLog& NetLog() const OVERRIDE { | |
150 NOTIMPLEMENTED(); | |
151 return net_log_; | |
152 } | |
153 virtual void SetSubresourceSpeculation() OVERRIDE { | |
154 NOTIMPLEMENTED(); | |
155 } | |
156 virtual void SetOmniboxSpeculation() OVERRIDE { | |
157 NOTIMPLEMENTED(); | |
158 } | |
159 virtual bool WasEverUsed() const OVERRIDE { | |
160 return true; | |
161 } | |
162 virtual bool UsingTCPFastOpen() const OVERRIDE { | |
163 return false; | |
164 } | |
165 virtual int64 NumBytesRead() const OVERRIDE { | |
166 NOTIMPLEMENTED(); | |
167 return 0; | |
168 } | |
169 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { | |
170 NOTIMPLEMENTED(); | |
171 return base::TimeDelta(); | |
172 } | |
173 virtual bool WasNpnNegotiated() const OVERRIDE { | |
174 return false; | |
175 } | |
176 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { | |
177 return net::kProtoUnknown; | |
178 } | |
179 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { | |
180 NOTIMPLEMENTED(); | |
181 return false; | |
182 } | |
183 | |
184 private: | |
185 MuxChannel* channel_; | |
186 | |
187 net::CompletionCallback read_callback_; | |
188 scoped_refptr<net::IOBuffer> read_buffer_; | |
189 int read_buffer_size_; | |
190 | |
191 bool write_pending_; | |
192 int write_result_; | |
193 net::CompletionCallback write_callback_; | |
194 | |
195 net::BoundNetLog net_log_; | |
196 | |
197 DISALLOW_COPY_AND_ASSIGN(MuxSocket); | |
198 }; | |
199 | |
200 | |
201 ChannelMultiplexer::MuxChannel::MuxChannel( | |
202 ChannelMultiplexer* multiplexer, | |
203 const std::string& name, | |
204 int send_id) | |
205 : multiplexer_(multiplexer), | |
206 name_(name), | |
207 send_id_(send_id), | |
208 id_sent_(false), | |
209 receive_id_(kChannelIdUnknown), | |
210 socket_(NULL) { | |
211 } | |
212 | |
213 ChannelMultiplexer::MuxChannel::~MuxChannel() { | |
214 // Socket must be destroyed before the channel. | |
215 DCHECK(!socket_); | |
216 STLDeleteElements(&pending_packets_); | |
217 } | |
218 | |
219 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { | |
220 DCHECK(!socket_); // Can't create more than one socket per channel. | |
221 scoped_ptr<MuxSocket> result(new MuxSocket(this)); | |
222 socket_ = result.get(); | |
223 return result.PassAs<net::StreamSocket>(); | |
224 } | |
225 | |
226 void ChannelMultiplexer::MuxChannel::OnIncomingPacket( | |
227 scoped_ptr<MultiplexPacket> packet, | |
228 const base::Closure& done_task) { | |
229 DCHECK_EQ(packet->channel_id(), receive_id_); | |
230 if (packet->data().size() > 0) { | |
231 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); | |
232 if (socket_) { | |
233 // Notify the socket that we have more data. | |
234 socket_->OnPacketReceived(); | |
235 } | |
236 } | |
237 } | |
238 | |
239 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { | |
240 if (socket_) | |
241 socket_->OnWriteFailed(); | |
242 } | |
243 | |
244 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { | |
245 DCHECK(socket_); | |
246 socket_ = NULL; | |
247 } | |
248 | |
249 bool ChannelMultiplexer::MuxChannel::DoWrite( | |
250 scoped_ptr<MultiplexPacket> packet, | |
251 const base::Closure& done_task) { | |
252 packet->set_channel_id(send_id_); | |
253 if (!id_sent_) { | |
254 packet->set_channel_name(name_); | |
255 id_sent_ = true; | |
256 } | |
257 return multiplexer_->DoWrite(packet.Pass(), done_task); | |
258 } | |
259 | |
260 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, | |
261 int buffer_len) { | |
262 int pos = 0; | |
263 while (buffer_len > 0 && !pending_packets_.empty()) { | |
264 DCHECK(!pending_packets_.front()->is_empty()); | |
265 int result = pending_packets_.front()->Read( | |
266 buffer->data() + pos, buffer_len); | |
267 DCHECK_LE(result, buffer_len); | |
268 pos += result; | |
269 buffer_len -= pos; | |
270 if (pending_packets_.front()->is_empty()) { | |
271 delete pending_packets_.front(); | |
272 pending_packets_.erase(pending_packets_.begin()); | |
273 } | |
274 } | |
275 return pos; | |
276 } | |
277 | |
278 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) | |
279 : channel_(channel), | |
280 read_buffer_size_(0), | |
281 write_pending_(false), | |
282 write_result_(0) { | |
283 } | |
284 | |
285 ChannelMultiplexer::MuxSocket::~MuxSocket() { | |
286 channel_->OnSocketDestroyed(); | |
287 } | |
288 | |
289 int ChannelMultiplexer::MuxSocket::Read( | |
290 net::IOBuffer* buffer, int buffer_len, | |
291 const net::CompletionCallback& callback) { | |
292 DCHECK(CalledOnValidThread()); | |
293 DCHECK(read_callback_.is_null()); | |
294 | |
295 int result = channel_->DoRead(buffer, buffer_len); | |
296 if (result == 0) { | |
297 read_buffer_ = buffer; | |
298 read_buffer_size_ = buffer_len; | |
299 read_callback_ = callback; | |
300 return net::ERR_IO_PENDING; | |
301 } | |
302 return result; | |
303 } | |
304 | |
305 int ChannelMultiplexer::MuxSocket::Write( | |
306 net::IOBuffer* buffer, int buffer_len, | |
307 const net::CompletionCallback& callback) { | |
308 DCHECK(CalledOnValidThread()); | |
309 | |
310 scoped_ptr<MultiplexPacket> packet(new MultiplexPacket()); | |
311 size_t size = std::min(kMaxPacketSize, buffer_len); | |
312 packet->mutable_data()->assign(buffer->data(), size); | |
313 | |
314 write_pending_ = true; | |
315 bool result = channel_->DoWrite(packet.Pass(), base::Bind( | |
316 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); | |
317 | |
318 if (!result) { | |
319 // Cannot complete the write, e.g. if the connection has been terminated. | |
320 return net::ERR_FAILED; | |
321 } | |
322 | |
323 // OnWriteComplete() might be called above synchronously. | |
324 if (write_pending_) { | |
325 DCHECK(write_callback_.is_null()); | |
326 write_callback_ = callback; | |
327 write_result_ = size; | |
328 return net::ERR_IO_PENDING; | |
329 } | |
330 | |
331 return size; | |
332 } | |
333 | |
334 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { | |
335 write_pending_ = false; | |
336 if (!write_callback_.is_null()) { | |
337 net::CompletionCallback cb; | |
338 std::swap(cb, write_callback_); | |
339 cb.Run(write_result_); | |
340 } | |
341 } | |
342 | |
343 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { | |
344 if (!write_callback_.is_null()) { | |
345 net::CompletionCallback cb; | |
346 std::swap(cb, write_callback_); | |
347 cb.Run(net::ERR_FAILED); | |
348 } | |
349 } | |
350 | |
351 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { | |
352 if (!read_callback_.is_null()) { | |
353 int result = channel_->DoRead(read_buffer_, read_buffer_size_); | |
354 read_buffer_ = NULL; | |
355 DCHECK_GT(result, 0); | |
356 net::CompletionCallback cb; | |
357 std::swap(cb, read_callback_); | |
358 cb.Run(result); | |
359 } | |
360 } | |
361 | |
362 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory, | |
363 const std::string& base_channel_name) | |
364 : base_channel_factory_(factory), | |
365 base_channel_name_(base_channel_name), | |
366 next_channel_id_(0), | |
367 destroyed_flag_(NULL) { | |
368 factory->CreateStreamChannel( | |
369 base_channel_name, | |
370 base::Bind(&ChannelMultiplexer::OnBaseChannelReady, | |
371 base::Unretained(this))); | |
372 } | |
373 | |
374 ChannelMultiplexer::~ChannelMultiplexer() { | |
375 DCHECK(pending_channels_.empty()); | |
376 STLDeleteValues(&channels_); | |
377 | |
378 // Cancel creation of the base channel if it hasn't finished. | |
379 if (base_channel_factory_) | |
380 base_channel_factory_->CancelChannelCreation(base_channel_name_); | |
381 | |
382 if (destroyed_flag_) | |
383 *destroyed_flag_ = true; | |
384 } | |
385 | |
386 void ChannelMultiplexer::CreateStreamChannel( | |
387 const std::string& name, | |
388 const StreamChannelCallback& callback) { | |
389 if (base_channel_.get()) { | |
390 // Already have |base_channel_|. Create new multiplexed channel | |
391 // synchronously. | |
392 callback.Run(GetOrCreateChannel(name)->CreateSocket()); | |
393 } else if (!base_channel_.get() && !base_channel_factory_) { | |
394 // Fail synchronously if we failed to create |base_channel_|. | |
395 callback.Run(scoped_ptr<net::StreamSocket>()); | |
396 } else { | |
397 // Still waiting for the |base_channel_|. | |
398 pending_channels_.push_back(PendingChannel(name, callback)); | |
399 } | |
400 } | |
401 | |
402 void ChannelMultiplexer::CreateDatagramChannel( | |
403 const std::string& name, | |
404 const DatagramChannelCallback& callback) { | |
405 NOTIMPLEMENTED(); | |
406 callback.Run(scoped_ptr<net::Socket>()); | |
407 } | |
408 | |
409 void ChannelMultiplexer::CancelChannelCreation(const std::string& name) { | |
410 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
411 it != pending_channels_.end(); ++it) { | |
412 if (it->name == name) { | |
413 pending_channels_.erase(it); | |
414 return; | |
415 } | |
416 } | |
417 NOTREACHED(); | |
Wez
2012/08/07 21:44:45
Do you really want a NOTREACHED() here; it seems v
Sergey Ulanov
2012/08/07 22:25:14
Yes, you are right.
| |
418 } | |
419 | |
420 void ChannelMultiplexer::OnBaseChannelReady( | |
421 scoped_ptr<net::StreamSocket> socket) { | |
422 base_channel_factory_ = NULL; | |
423 base_channel_ = socket.Pass(); | |
424 | |
425 if (!base_channel_.get()) { | |
426 // Notify all callers that we can't create any channels. | |
427 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
428 it != pending_channels_.end(); ++it) { | |
429 it->callback.Run(scoped_ptr<net::StreamSocket>()); | |
430 } | |
431 pending_channels_.clear(); | |
432 return; | |
433 } | |
434 | |
435 // Initialize reader and writer. | |
436 reader_.Init(base_channel_.get(), | |
437 base::Bind(&ChannelMultiplexer::OnIncomingPacket, | |
438 base::Unretained(this))); | |
439 writer_.Init(base_channel_.get(), | |
440 base::Bind(&ChannelMultiplexer::OnWriteFailed, | |
441 base::Unretained(this))); | |
442 | |
443 // Now create all pending channels. | |
444 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
445 it != pending_channels_.end(); ++it) { | |
446 it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); | |
447 } | |
448 pending_channels_.clear(); | |
449 } | |
450 | |
451 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( | |
452 const std::string& name) { | |
453 // Check if we already have a channel with the requested name. | |
454 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | |
455 if (it != channels_.end()) | |
456 return it->second; | |
457 | |
458 // Create a new channel if we haven't found existing one. | |
459 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | |
460 ++next_channel_id_; | |
461 channels_[channel->name()] = channel; | |
462 return channel; | |
463 } | |
464 | |
465 | |
466 void ChannelMultiplexer::OnWriteFailed(int error) { | |
467 bool destroyed = false; | |
468 destroyed_flag_ = &destroyed; | |
469 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | |
470 it != channels_.end(); ++it) { | |
471 it->second->OnWriteFailed(); | |
472 if (destroyed) | |
473 return; | |
474 } | |
475 destroyed_flag_ = NULL; | |
476 } | |
477 | |
478 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MultiplexPacket> packet, | |
479 const base::Closure& done_task) { | |
480 if (!packet->has_channel_id()) { | |
481 LOG(ERROR) << "Received packet without channel_id."; | |
482 done_task.Run(); | |
483 return; | |
484 } | |
485 | |
486 int receive_id = packet->channel_id(); | |
487 MuxChannel* channel = NULL; | |
488 std::map<int, MuxChannel*>::iterator it = | |
489 channels_by_receive_id_.find(receive_id); | |
490 if (it != channels_by_receive_id_.end()) { | |
491 channel = it->second; | |
492 } else { | |
493 // This is a new |channel_id| we haven't seen before. Look it up by name. | |
494 if (!packet->has_channel_name()) { | |
495 LOG(ERROR) << "Received packet with unknown channel_id and " | |
496 "without channel_name."; | |
497 done_task.Run(); | |
498 return; | |
499 } | |
500 channel = GetOrCreateChannel(packet->channel_name()); | |
501 channel->set_receive_id(receive_id); | |
502 channels_by_receive_id_[receive_id] = channel; | |
503 } | |
504 | |
505 channel->OnIncomingPacket(packet.Pass(), done_task); | |
506 } | |
507 | |
508 bool ChannelMultiplexer::DoWrite(scoped_ptr<MultiplexPacket> packet, | |
509 const base::Closure& done_task) { | |
510 return writer_.Write(SerializeAndFrameMessage(*packet), done_task); | |
511 } | |
512 | |
513 } // namespace protocol | |
514 } // namespace remoting | |
OLD | NEW |