OLD | NEW |
---|---|
(Empty) | |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "remoting/protocol/channel_multiplexer.h" | |
6 | |
7 #include <string.h> | |
8 | |
9 #include "base/bind.h" | |
10 #include "base/callback.h" | |
11 #include "base/location.h" | |
12 #include "base/stl_util.h" | |
13 #include "net/base/net_errors.h" | |
14 #include "net/socket/stream_socket.h" | |
15 #include "remoting/protocol/util.h" | |
16 | |
17 namespace remoting { | |
18 namespace protocol { | |
19 | |
20 namespace { | |
21 const int kChannelIdUnknown = -1; | |
22 const int kMaxPacketSize = 1024; | |
23 | |
24 class PendingPacket { | |
25 public: | |
26 PendingPacket(scoped_ptr<MuxPacket> packet, const base::Closure& done_task) | |
27 : packet(packet.Pass()), | |
28 done_task(done_task), | |
29 pos(0U) { | |
30 } | |
31 ~PendingPacket() { | |
32 done_task.Run(); | |
33 } | |
34 | |
35 bool is_empty() { return pos >= packet->data().size(); } | |
36 | |
37 int Read(char* buffer, size_t size) { | |
38 size = std::min(size, packet->data().size() - pos); | |
39 memcpy(buffer, packet->data().data() + pos, size); | |
40 pos += size; | |
41 return size; | |
42 } | |
43 | |
44 private: | |
45 scoped_ptr<MuxPacket> packet; | |
46 base::Closure done_task; | |
47 size_t pos; | |
48 | |
49 DISALLOW_COPY_AND_ASSIGN(PendingPacket); | |
50 }; | |
51 | |
52 } // namespace | |
53 | |
54 const char ChannelMultiplexer::kMuxChannelName[] = "mux"; | |
55 | |
56 struct ChannelMultiplexer::PendingChannel { | |
57 PendingChannel(const std::string& name, | |
58 const StreamChannelCallback& callback) | |
59 : name(name), callback(callback) { | |
60 } | |
61 std::string name; | |
62 StreamChannelCallback callback; | |
63 }; | |
64 | |
65 class ChannelMultiplexer::MuxChannel { | |
66 public: | |
67 MuxChannel(ChannelMultiplexer* multiplexer, const std::string& name, int id); | |
68 ~MuxChannel(); | |
69 | |
70 const std::string& name() { return name_; } | |
71 int id() { return id_; } | |
Wez
2012/08/06 23:14:07
nit: local_id?
Sergey Ulanov
2012/08/07 20:12:22
renamed to send id
| |
72 int remote_id() { return remote_id_; } | |
Wez
2012/08/06 23:14:07
nit: Or received_id() and send_id()?
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
73 void set_remote_id(int id) { remote_id_ = id; } | |
74 | |
75 // Called by ChannelMultiplexer. | |
76 scoped_ptr<net::StreamSocket> CreateSocket(); | |
77 void OnIncomingPacket(scoped_ptr<MuxPacket> packet, | |
78 const base::Closure& done_task); | |
79 void OnWriteFailed(); | |
80 | |
81 // Called by MuxSocket. | |
82 void OnSocketDestroyed(); | |
83 void DoWrite(scoped_ptr<MuxPacket> packet, const base::Closure& done_task); | |
84 int DoRead(net::IOBuffer* buffer, int buffer_len); | |
85 | |
86 private: | |
87 ChannelMultiplexer* multiplexer_; | |
88 std::string name_; | |
89 int id_; | |
90 bool id_sent_; | |
91 int remote_id_; | |
92 MuxSocket* socket_; | |
93 std::list<PendingPacket*> pending_packets_; | |
94 | |
95 DISALLOW_COPY_AND_ASSIGN(MuxChannel); | |
96 }; | |
97 | |
98 class ChannelMultiplexer::MuxSocket : public net::StreamSocket, | |
99 public base::NonThreadSafe, | |
100 public base::SupportsWeakPtr<MuxSocket> { | |
101 public: | |
102 MuxSocket(MuxChannel* channel); | |
103 ~MuxSocket(); | |
104 | |
105 void OnWriteComplete(); | |
106 void OnWriteFailed(); | |
107 void OnPacketReceived(); | |
108 | |
109 // net::StreamSocket interface. | |
110 virtual int Read(net::IOBuffer* buffer, int buffer_len, | |
111 const net::CompletionCallback& callback) OVERRIDE; | |
112 virtual int Write(net::IOBuffer* buffer, int buffer_len, | |
113 const net::CompletionCallback& callback) OVERRIDE; | |
114 | |
115 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { | |
116 NOTIMPLEMENTED(); | |
117 return false; | |
118 } | |
119 virtual bool SetSendBufferSize(int32 size) OVERRIDE { | |
120 NOTIMPLEMENTED(); | |
121 return false; | |
122 } | |
123 | |
124 virtual int Connect(const net::CompletionCallback& callback) OVERRIDE { | |
125 NOTIMPLEMENTED(); | |
126 return net::ERR_FAILED; | |
127 } | |
128 virtual void Disconnect() OVERRIDE { | |
129 NOTIMPLEMENTED(); | |
130 } | |
131 virtual bool IsConnected() const OVERRIDE { | |
132 NOTIMPLEMENTED(); | |
133 return true; | |
134 } | |
135 virtual bool IsConnectedAndIdle() const OVERRIDE { | |
136 NOTIMPLEMENTED(); | |
137 return false; | |
138 } | |
139 virtual int GetPeerAddress(net::IPEndPoint* address) const OVERRIDE { | |
140 NOTIMPLEMENTED(); | |
141 return net::ERR_FAILED; | |
142 } | |
143 virtual int GetLocalAddress(net::IPEndPoint* address) const OVERRIDE { | |
144 NOTIMPLEMENTED(); | |
145 return net::ERR_FAILED; | |
146 } | |
147 virtual const net::BoundNetLog& NetLog() const OVERRIDE { | |
148 NOTIMPLEMENTED(); | |
149 return net_log_; | |
150 } | |
151 virtual void SetSubresourceSpeculation() OVERRIDE { | |
152 NOTIMPLEMENTED(); | |
153 } | |
154 virtual void SetOmniboxSpeculation() OVERRIDE { | |
155 NOTIMPLEMENTED(); | |
156 } | |
157 virtual bool WasEverUsed() const OVERRIDE { | |
158 return true; | |
159 } | |
160 virtual bool UsingTCPFastOpen() const OVERRIDE { | |
161 return false; | |
162 } | |
163 virtual int64 NumBytesRead() const OVERRIDE { | |
164 NOTIMPLEMENTED(); | |
165 return 0; | |
166 } | |
167 virtual base::TimeDelta GetConnectTimeMicros() const OVERRIDE { | |
168 NOTIMPLEMENTED(); | |
169 return base::TimeDelta(); | |
170 } | |
171 virtual bool WasNpnNegotiated() const OVERRIDE { | |
172 return false; | |
173 } | |
174 virtual net::NextProto GetNegotiatedProtocol() const OVERRIDE { | |
175 return net::kProtoUnknown; | |
176 } | |
177 virtual bool GetSSLInfo(net::SSLInfo* ssl_info) OVERRIDE { | |
178 NOTIMPLEMENTED(); | |
179 return false; | |
180 } | |
181 | |
182 private: | |
183 MuxChannel* channel_; | |
184 | |
185 net::CompletionCallback read_callback_; | |
186 scoped_refptr<net::IOBuffer> read_buffer_; | |
187 int read_buffer_size_; | |
188 | |
189 bool write_pending_; | |
190 int write_result_; | |
191 net::CompletionCallback write_callback_; | |
192 bool write_failed_; | |
193 | |
194 net::BoundNetLog net_log_; | |
195 | |
196 DISALLOW_COPY_AND_ASSIGN(MuxSocket); | |
197 }; | |
198 | |
199 | |
200 ChannelMultiplexer::MuxChannel::MuxChannel( | |
201 ChannelMultiplexer* multiplexer, | |
202 const std::string& name, | |
203 int id) | |
204 : multiplexer_(multiplexer), | |
205 name_(name), | |
206 id_(id), | |
207 id_sent_(false), | |
208 remote_id_(kChannelIdUnknown), | |
209 socket_(NULL) { | |
210 } | |
211 | |
212 ChannelMultiplexer::MuxChannel::~MuxChannel() { | |
213 // Socket must be destroyed before the channel. | |
214 DCHECK(!socket_); | |
215 STLDeleteElements(&pending_packets_); | |
216 } | |
217 | |
218 scoped_ptr<net::StreamSocket> ChannelMultiplexer::MuxChannel::CreateSocket() { | |
219 DCHECK(!socket_); // Can't create more than one socket per channel. | |
220 scoped_ptr<MuxSocket> result(new MuxSocket(this)); | |
221 socket_ = result.get(); | |
222 return result.PassAs<net::StreamSocket>(); | |
223 } | |
224 | |
225 void ChannelMultiplexer::MuxChannel::OnIncomingPacket( | |
226 scoped_ptr<MuxPacket> packet, | |
227 const base::Closure& done_task) { | |
228 DCHECK_EQ(packet->channel_id(), remote_id_); | |
229 if (packet->data().size() > 0) { | |
230 pending_packets_.push_back(new PendingPacket(packet.Pass(), done_task)); | |
231 if (socket_) { | |
232 // Notify the socket that we have more data. | |
233 socket_->OnPacketReceived(); | |
234 } | |
235 } | |
236 } | |
237 | |
238 void ChannelMultiplexer::MuxChannel::OnWriteFailed() { | |
239 if (socket_) | |
240 socket_->OnWriteFailed(); | |
241 } | |
242 | |
243 void ChannelMultiplexer::MuxChannel::OnSocketDestroyed() { | |
244 DCHECK(socket_); | |
245 socket_ = NULL; | |
Wez
2012/08/06 23:14:07
nit: Remove the MuxChannel from the list of active
Sergey Ulanov
2012/08/07 20:12:22
We need to keep this object in order to preserve n
| |
246 } | |
247 | |
248 void ChannelMultiplexer::MuxChannel::DoWrite( | |
249 scoped_ptr<MuxPacket> packet, | |
250 const base::Closure& done_task) { | |
251 packet->set_channel_id(id_); | |
252 if (!id_sent_) { | |
253 packet->set_channel_name(name_); | |
254 id_sent_ = true; | |
255 } | |
256 multiplexer_->DoWrite(packet.Pass(), done_task); | |
257 } | |
258 | |
259 int ChannelMultiplexer::MuxChannel::DoRead(net::IOBuffer* buffer, | |
260 int buffer_len) { | |
261 int pos = 0; | |
262 while (buffer_len > 0 && !pending_packets_.empty()) { | |
263 int result = pending_packets_.front()->Read( | |
264 buffer->data() + pos, buffer_len); | |
265 DCHECK_LE(result, buffer_len); | |
266 pos += result; | |
267 buffer_len -= pos; | |
268 while (!pending_packets_.empty() && | |
269 pending_packets_.front()->is_empty()) { | |
Wez
2012/08/06 23:14:07
Do you need a while loop here? Under what circumst
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
270 delete pending_packets_.front(); | |
271 pending_packets_.erase(pending_packets_.begin()); | |
272 } | |
273 } | |
274 return pos; | |
275 } | |
276 | |
277 ChannelMultiplexer::MuxSocket::MuxSocket(MuxChannel* channel) | |
278 : channel_(channel), | |
279 read_buffer_size_(0), | |
280 write_pending_(false), | |
281 write_result_(0), | |
282 write_failed_(false) { | |
283 } | |
284 | |
285 ChannelMultiplexer::MuxSocket::~MuxSocket() { | |
Wez
2012/08/06 23:14:07
Thread check before OnSocketDestroyed()?
Sergey Ulanov
2012/08/07 20:12:22
The class inherits from NonThreadSafe and ~NonThre
Wez
2012/08/07 21:44:44
Not until after OnSocketDestroyed() has already be
| |
286 channel_->OnSocketDestroyed(); | |
287 } | |
288 | |
289 int ChannelMultiplexer::MuxSocket::Read( | |
290 net::IOBuffer* buffer, int buffer_len, | |
291 const net::CompletionCallback& callback) { | |
292 DCHECK(CalledOnValidThread()); | |
293 DCHECK(read_callback_.is_null()); | |
294 | |
295 int result = channel_->DoRead(buffer, buffer_len); | |
296 if (result == 0) { | |
297 read_buffer_ = buffer; | |
298 read_buffer_size_ = buffer_len; | |
299 read_callback_ = callback; | |
300 return net::ERR_IO_PENDING; | |
301 } | |
302 return result; | |
303 } | |
304 | |
305 int ChannelMultiplexer::MuxSocket::Write( | |
306 net::IOBuffer* buffer, int buffer_len, | |
307 const net::CompletionCallback& callback) { | |
308 DCHECK(CalledOnValidThread()); | |
309 | |
310 if (write_failed_) | |
311 return net::ERR_FAILED; | |
312 | |
313 scoped_ptr<MuxPacket> packet(new MuxPacket()); | |
314 size_t size = std::min(kMaxPacketSize, buffer_len); | |
315 packet->mutable_data()->assign(buffer->data(), size); | |
316 | |
317 write_pending_ = true; | |
318 channel_->DoWrite(packet.Pass(), base::Bind( | |
319 &ChannelMultiplexer::MuxSocket::OnWriteComplete, AsWeakPtr())); | |
Wez
2012/08/06 23:14:07
nit: It's a little confusing that we set write_pen
Sergey Ulanov
2012/08/07 20:12:22
The semantics used by BufferedSocketWriter is easi
| |
320 | |
321 if (write_pending_) { | |
322 DCHECK(write_callback_.is_null()); | |
323 write_callback_ = callback; | |
324 write_result_ = size; | |
325 return net::ERR_IO_PENDING; | |
326 } | |
327 | |
328 return write_failed_ ? net::ERR_FAILED : size; | |
329 } | |
330 | |
331 void ChannelMultiplexer::MuxSocket::OnWriteComplete() { | |
332 write_pending_ = false; | |
333 if (!write_callback_.is_null()) { | |
334 net::CompletionCallback cb; | |
335 std::swap(cb, write_callback_); | |
336 cb.Run(write_result_); | |
337 } | |
338 } | |
339 | |
340 void ChannelMultiplexer::MuxSocket::OnWriteFailed() { | |
341 if (!write_callback_.is_null()) { | |
342 net::CompletionCallback cb; | |
343 std::swap(cb, write_callback_); | |
344 cb.Run(net::ERR_FAILED); | |
Wez
2012/08/06 23:14:07
Don't you need to do this after setting write_fail
Sergey Ulanov
2012/08/07 20:12:22
Good point. Fixed.
| |
345 } | |
346 write_failed_ = true; | |
Wez
2012/08/06 23:14:07
nit: Do you actually need to keep a write_failed f
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
347 } | |
348 | |
349 void ChannelMultiplexer::MuxSocket::OnPacketReceived() { | |
350 if (!read_callback_.is_null()) { | |
351 int result = channel_->DoRead(read_buffer_, read_buffer_size_); | |
352 read_buffer_ = NULL; | |
353 DCHECK_GT(result, 0); | |
354 net::CompletionCallback cb; | |
355 std::swap(cb, read_callback_); | |
356 cb.Run(result); | |
357 } | |
358 } | |
359 | |
360 ChannelMultiplexer::ChannelMultiplexer(ChannelFactory* factory) | |
361 : base_channel_failed_(false), | |
362 next_channel_id_(0) { | |
363 factory->CreateStreamChannel( | |
364 kMuxChannelName, base::Bind(&ChannelMultiplexer::OnBaseChannelReady, | |
365 base::Unretained(this))); | |
366 } | |
367 | |
368 ChannelMultiplexer::~ChannelMultiplexer() { | |
369 STLDeleteValues(&channels_); | |
370 } | |
371 | |
372 void ChannelMultiplexer::CreateStreamChannel( | |
373 const std::string& name, | |
374 const StreamChannelCallback& callback) { | |
375 if (base_channel_.get()) { | |
376 // Already have |base_channel_|. Create new multiplexed channel | |
377 // synchronously. | |
378 callback.Run(GetOrCreateChannel(name)->CreateSocket()); | |
379 } else if (base_channel_failed_) { | |
380 // Fail synchronously if we failed to create |base_channel_|. | |
381 callback.Run(scoped_ptr<net::StreamSocket>()); | |
382 } else { | |
383 // Still waiting for the |base_channel_|. | |
384 pending_channels_.push_back(PendingChannel(name, callback)); | |
385 } | |
386 } | |
387 | |
388 void ChannelMultiplexer::CreateDatagramChannel( | |
389 const std::string& name, | |
390 const DatagramChannelCallback& callback) { | |
391 NOTIMPLEMENTED(); | |
392 callback.Run(scoped_ptr<net::Socket>()); | |
393 } | |
394 | |
395 void ChannelMultiplexer::OnBaseChannelReady( | |
396 scoped_ptr<net::StreamSocket> socket) { | |
397 base_channel_ = socket.Pass(); | |
398 | |
399 if (!base_channel_.get()) { | |
400 base_channel_failed_ = true; | |
401 | |
402 // Notify all callers that we can't create any channels. | |
403 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
404 it != pending_channels_.end(); ++it) { | |
405 it->callback.Run(scoped_ptr<net::StreamSocket>()); | |
406 } | |
407 pending_channels_.clear(); | |
408 return; | |
409 } | |
410 | |
411 // Initialize reader and writer. | |
412 reader_.Init(base_channel_.get(), | |
413 base::Bind(&ChannelMultiplexer::OnIncomingPacket, | |
414 base::Unretained(this))); | |
415 writer_.Init(base_channel_.get(), | |
416 base::Bind(&ChannelMultiplexer::OnWriteFailed, | |
417 base::Unretained(this))); | |
418 | |
419 // Now create all pending channels. | |
420 for (std::list<PendingChannel>::iterator it = pending_channels_.begin(); | |
421 it != pending_channels_.end(); ++it) { | |
422 it->callback.Run(GetOrCreateChannel(it->name)->CreateSocket()); | |
423 } | |
424 pending_channels_.clear(); | |
425 } | |
426 | |
427 ChannelMultiplexer::MuxChannel* ChannelMultiplexer::GetOrCreateChannel( | |
428 const std::string& name) { | |
429 // Check if we already have a channel with the requested name. | |
430 std::map<std::string, MuxChannel*>::iterator it = channels_.find(name); | |
431 if (it != channels_.end()) | |
432 return it->second; | |
433 | |
434 // Create a new channel if we haven't found existing one. | |
435 MuxChannel* channel = new MuxChannel(this, name, next_channel_id_); | |
436 ++next_channel_id_; | |
437 channels_[channel->name()] = channel; | |
438 return channel; | |
439 } | |
440 | |
441 | |
442 void ChannelMultiplexer::OnWriteFailed(int error) { | |
443 for (std::map<std::string, MuxChannel*>::iterator it = channels_.begin(); | |
444 it != channels_.end(); ++it) { | |
445 it->second->OnWriteFailed(); | |
Wez
2012/08/06 23:14:07
This triggers caller-supplied callbacks, which mig
Sergey Ulanov
2012/08/07 20:12:22
Fixed this method to handle the case when multiple
| |
446 } | |
447 } | |
448 | |
449 void ChannelMultiplexer::OnIncomingPacket(scoped_ptr<MuxPacket> packet, | |
450 const base::Closure& done_task) { | |
451 if (!packet->has_channel_id()) { | |
452 LOG(ERROR) << "Received packet without channel_id."; | |
453 done_task.Run(); | |
454 return; | |
455 } | |
456 | |
457 int remote_id = packet->channel_id(); | |
458 MuxChannel* channel = NULL; | |
459 std::map<int, MuxChannel*>::iterator it = | |
460 channels_by_remote_id_.find(remote_id); | |
461 if (it != channels_by_remote_id_.end()) { | |
462 channel = it->second; | |
463 } else { | |
464 // This is a new |channel_id| we haven't seen before. Look it up by name. | |
465 if (!packet->has_channel_name()) { | |
466 LOG(ERROR) << "Received packet with unknown channel_id and " | |
467 "without channel_name."; | |
468 done_task.Run(); | |
469 return; | |
470 } | |
471 channel = GetOrCreateChannel(packet->channel_name()); | |
Wez
2012/08/06 23:14:07
This has the disadvantage that if a peer sends lot
Sergey Ulanov
2012/08/07 20:12:22
MessageReader doesn't try to read from the channel
| |
472 channel->set_remote_id(remote_id); | |
473 channels_by_remote_id_[remote_id] = channel; | |
474 } | |
475 | |
476 channel->OnIncomingPacket(packet.Pass(), done_task); | |
477 } | |
478 | |
479 void ChannelMultiplexer::DoWrite(scoped_ptr<MuxPacket> packet, | |
480 const base::Closure& done_task) { | |
Wez
2012/08/06 23:14:07
nit: Indentation.
Sergey Ulanov
2012/08/07 20:12:22
Done.
| |
481 writer_.Write(SerializeAndFrameMessage(*packet), done_task); | |
482 } | |
483 | |
484 } // namespace protocol | |
485 } // namespace remoting | |
OLD | NEW |