OLD | NEW |
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "remoting/jingle_glue/ssl_socket_adapter.h" | 5 #include "remoting/jingle_glue/ssl_socket_adapter.h" |
6 | 6 |
7 #include "base/base64.h" | 7 #include "base/base64.h" |
8 #include "base/compiler_specific.h" | 8 #include "base/compiler_specific.h" |
9 #include "base/message_loop.h" | 9 #include "base/message_loop.h" |
10 #include "jingle/glue/utils.h" | 10 #include "jingle/glue/utils.h" |
11 #include "net/base/address_list.h" | 11 #include "net/base/address_list.h" |
12 #include "net/base/cert_verifier.h" | 12 #include "net/base/cert_verifier.h" |
13 #include "net/base/host_port_pair.h" | 13 #include "net/base/host_port_pair.h" |
14 #include "net/base/net_errors.h" | 14 #include "net/base/net_errors.h" |
15 #include "net/base/ssl_config_service.h" | 15 #include "net/base/ssl_config_service.h" |
16 #include "net/base/sys_addrinfo.h" | 16 #include "net/base/sys_addrinfo.h" |
17 #include "net/socket/client_socket_factory.h" | 17 #include "net/socket/client_socket_factory.h" |
18 #include "net/url_request/url_request_context.h" | 18 #include "net/url_request/url_request_context.h" |
19 | 19 |
20 namespace remoting { | 20 namespace remoting { |
21 | 21 |
22 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { | 22 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { |
23 return new SSLSocketAdapter(socket); | 23 return new SSLSocketAdapter(socket); |
24 } | 24 } |
25 | 25 |
26 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) | 26 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) |
27 : SSLAdapter(socket), | 27 : SSLAdapter(socket), |
28 ignore_bad_cert_(false), | 28 ignore_bad_cert_(false), |
29 cert_verifier_(new net::CertVerifier()), | 29 cert_verifier_(new net::CertVerifier()), |
30 ALLOW_THIS_IN_INITIALIZER_LIST( | |
31 connected_callback_(this, &SSLSocketAdapter::OnConnected)), | |
32 ALLOW_THIS_IN_INITIALIZER_LIST( | |
33 read_callback_(this, &SSLSocketAdapter::OnRead)), | |
34 ALLOW_THIS_IN_INITIALIZER_LIST( | |
35 write_callback_(this, &SSLSocketAdapter::OnWrite)), | |
36 ssl_state_(SSLSTATE_NONE), | 30 ssl_state_(SSLSTATE_NONE), |
37 read_state_(IOSTATE_NONE), | 31 read_state_(IOSTATE_NONE), |
38 write_state_(IOSTATE_NONE) { | 32 write_state_(IOSTATE_NONE) { |
39 transport_socket_ = new TransportSocket(socket, this); | 33 transport_socket_ = new TransportSocket(socket, this); |
40 } | 34 } |
41 | 35 |
42 SSLSocketAdapter::~SSLSocketAdapter() { | 36 SSLSocketAdapter::~SSLSocketAdapter() { |
43 } | 37 } |
44 | 38 |
45 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { | 39 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { |
(...skipping 24 matching lines...) Expand all Loading... |
70 net::SSLConfig ssl_config; | 64 net::SSLConfig ssl_config; |
71 net::SSLClientSocketContext context; | 65 net::SSLClientSocketContext context; |
72 context.cert_verifier = cert_verifier_.get(); | 66 context.cert_verifier = cert_verifier_.get(); |
73 | 67 |
74 transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); | 68 transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); |
75 ssl_socket_.reset( | 69 ssl_socket_.reset( |
76 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( | 70 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( |
77 transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, | 71 transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, |
78 NULL /* ssl_host_info */, context)); | 72 NULL /* ssl_host_info */, context)); |
79 | 73 |
80 int result = ssl_socket_->Connect(&connected_callback_); | 74 int result = ssl_socket_->Connect( |
| 75 base::Bind(&SSLSocketAdapter::OnConnected, base::Unretained(this))); |
81 | 76 |
82 if (result == net::ERR_IO_PENDING || result == net::OK) { | 77 if (result == net::ERR_IO_PENDING || result == net::OK) { |
83 return 0; | 78 return 0; |
84 } else { | 79 } else { |
85 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); | 80 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); |
86 return result; | 81 return result; |
87 } | 82 } |
88 } | 83 } |
89 | 84 |
90 int SSLSocketAdapter::Send(const void* buf, size_t len) { | 85 int SSLSocketAdapter::Send(const void* buf, size_t len) { |
91 if (ssl_state_ != SSLSTATE_CONNECTED) { | 86 if (ssl_state_ != SSLSTATE_CONNECTED) { |
92 return AsyncSocketAdapter::Send(buf, len); | 87 return AsyncSocketAdapter::Send(buf, len); |
93 } else { | 88 } else { |
94 scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len)); | 89 scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len)); |
95 memcpy(transport_buf->data(), buf, len); | 90 memcpy(transport_buf->data(), buf, len); |
96 | 91 |
97 int result = ssl_socket_->Write(transport_buf, len, NULL); | 92 int result = ssl_socket_->Write(transport_buf, len, |
| 93 net::CompletionCallback()); |
98 if (result == net::ERR_IO_PENDING) { | 94 if (result == net::ERR_IO_PENDING) { |
99 SetError(EWOULDBLOCK); | 95 SetError(EWOULDBLOCK); |
100 } | 96 } |
101 transport_buf = NULL; | 97 transport_buf = NULL; |
102 return result; | 98 return result; |
103 } | 99 } |
104 } | 100 } |
105 | 101 |
106 int SSLSocketAdapter::Recv(void* buf, size_t len) { | 102 int SSLSocketAdapter::Recv(void* buf, size_t len) { |
107 switch (ssl_state_) { | 103 switch (ssl_state_) { |
108 case SSLSTATE_NONE: | 104 case SSLSTATE_NONE: |
109 return AsyncSocketAdapter::Recv(buf, len); | 105 return AsyncSocketAdapter::Recv(buf, len); |
110 | 106 |
111 case SSLSTATE_WAIT: | 107 case SSLSTATE_WAIT: |
112 SetError(EWOULDBLOCK); | 108 SetError(EWOULDBLOCK); |
113 return -1; | 109 return -1; |
114 | 110 |
115 case SSLSTATE_CONNECTED: | 111 case SSLSTATE_CONNECTED: |
116 switch (read_state_) { | 112 switch (read_state_) { |
117 case IOSTATE_NONE: { | 113 case IOSTATE_NONE: { |
118 transport_buf_ = new net::IOBuffer(len); | 114 transport_buf_ = new net::IOBuffer(len); |
119 int result = ssl_socket_->Read(transport_buf_, len, &read_callback_); | 115 int result = ssl_socket_->Read( |
| 116 transport_buf_, len, |
| 117 base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); |
120 if (result >= 0) { | 118 if (result >= 0) { |
121 memcpy(buf, transport_buf_->data(), len); | 119 memcpy(buf, transport_buf_->data(), len); |
122 } | 120 } |
123 | 121 |
124 if (result == net::ERR_IO_PENDING) { | 122 if (result == net::ERR_IO_PENDING) { |
125 read_state_ = IOSTATE_PENDING; | 123 read_state_ = IOSTATE_PENDING; |
126 SetError(EWOULDBLOCK); | 124 SetError(EWOULDBLOCK); |
127 } else { | 125 } else { |
128 if (result < 0) { | 126 if (result < 0) { |
129 SetError(result); | 127 SetError(result); |
(...skipping 50 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
180 int result = BeginSSL(); | 178 int result = BeginSSL(); |
181 if (0 != result) { | 179 if (0 != result) { |
182 // TODO(zork): Handle this case gracefully. | 180 // TODO(zork): Handle this case gracefully. |
183 LOG(WARNING) << "BeginSSL() failed with " << result; | 181 LOG(WARNING) << "BeginSSL() failed with " << result; |
184 } | 182 } |
185 } | 183 } |
186 } | 184 } |
187 | 185 |
188 TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, | 186 TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, |
189 SSLSocketAdapter *ssl_adapter) | 187 SSLSocketAdapter *ssl_adapter) |
190 : old_read_callback_(NULL), | 188 : read_buffer_len_(0), |
191 write_callback_(NULL), | |
192 read_buffer_len_(0), | |
193 write_buffer_len_(0), | 189 write_buffer_len_(0), |
194 socket_(socket), | 190 socket_(socket), |
195 was_used_to_convey_data_(false) { | 191 was_used_to_convey_data_(false) { |
196 socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); | 192 socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); |
197 socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); | 193 socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); |
198 } | 194 } |
199 | 195 |
200 TransportSocket::~TransportSocket() { | 196 TransportSocket::~TransportSocket() { |
201 } | 197 } |
202 | 198 |
203 int TransportSocket::Connect(net::OldCompletionCallback* callback) { | |
204 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter | |
205 // calls Connect() on socket_ directly. | |
206 NOTREACHED(); | |
207 return false; | |
208 } | |
209 int TransportSocket::Connect(const net::CompletionCallback& callback) { | 199 int TransportSocket::Connect(const net::CompletionCallback& callback) { |
210 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter | 200 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter |
211 // calls Connect() on socket_ directly. | 201 // calls Connect() on socket_ directly. |
212 NOTREACHED(); | 202 NOTREACHED(); |
213 return false; | 203 return false; |
214 } | 204 } |
215 | 205 |
216 void TransportSocket::Disconnect() { | 206 void TransportSocket::Disconnect() { |
217 socket_->Close(); | 207 socket_->Close(); |
218 } | 208 } |
(...skipping 62 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
281 NOTREACHED(); | 271 NOTREACHED(); |
282 return -1; | 272 return -1; |
283 } | 273 } |
284 | 274 |
285 base::TimeDelta TransportSocket::GetConnectTimeMicros() const { | 275 base::TimeDelta TransportSocket::GetConnectTimeMicros() const { |
286 NOTREACHED(); | 276 NOTREACHED(); |
287 return base::TimeDelta::FromMicroseconds(-1); | 277 return base::TimeDelta::FromMicroseconds(-1); |
288 } | 278 } |
289 | 279 |
290 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, | 280 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, |
291 net::OldCompletionCallback* callback) { | 281 const net::CompletionCallback& callback) { |
292 DCHECK(buf); | 282 DCHECK(buf); |
293 DCHECK(!old_read_callback_ && read_callback_.is_null()); | 283 DCHECK(read_callback_.is_null()); |
294 DCHECK(!read_buffer_.get()); | 284 DCHECK(!read_buffer_.get()); |
295 int result = socket_->Recv(buf->data(), buf_len); | 285 int result = socket_->Recv(buf->data(), buf_len); |
296 if (result < 0) { | 286 if (result < 0) { |
297 result = net::MapSystemError(socket_->GetError()); | |
298 if (result == net::ERR_IO_PENDING) { | |
299 old_read_callback_ = callback; | |
300 read_buffer_ = buf; | |
301 read_buffer_len_ = buf_len; | |
302 } | |
303 } | |
304 if (result != net::ERR_IO_PENDING) | |
305 was_used_to_convey_data_ = true; | |
306 return result; | |
307 } | |
308 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, | |
309 const net::CompletionCallback& callback) { | |
310 DCHECK(buf); | |
311 DCHECK(!old_read_callback_ && read_callback_.is_null()); | |
312 DCHECK(!read_buffer_.get()); | |
313 int result = socket_->Recv(buf->data(), buf_len); | |
314 if (result < 0) { | |
315 result = net::MapSystemError(socket_->GetError()); | 287 result = net::MapSystemError(socket_->GetError()); |
316 if (result == net::ERR_IO_PENDING) { | 288 if (result == net::ERR_IO_PENDING) { |
317 read_callback_ = callback; | 289 read_callback_ = callback; |
318 read_buffer_ = buf; | 290 read_buffer_ = buf; |
319 read_buffer_len_ = buf_len; | 291 read_buffer_len_ = buf_len; |
320 } | 292 } |
321 } | 293 } |
322 if (result != net::ERR_IO_PENDING) | 294 if (result != net::ERR_IO_PENDING) |
323 was_used_to_convey_data_ = true; | 295 was_used_to_convey_data_ = true; |
324 return result; | 296 return result; |
325 } | 297 } |
326 | 298 |
327 int TransportSocket::Write(net::IOBuffer* buf, int buf_len, | 299 int TransportSocket::Write(net::IOBuffer* buf, int buf_len, |
328 net::OldCompletionCallback* callback) { | 300 const net::CompletionCallback& callback) { |
329 DCHECK(buf); | 301 DCHECK(buf); |
330 DCHECK(!write_callback_); | 302 DCHECK(write_callback_.is_null()); |
331 DCHECK(!write_buffer_.get()); | 303 DCHECK(!write_buffer_.get()); |
332 int result = socket_->Send(buf->data(), buf_len); | 304 int result = socket_->Send(buf->data(), buf_len); |
333 if (result < 0) { | 305 if (result < 0) { |
334 result = net::MapSystemError(socket_->GetError()); | 306 result = net::MapSystemError(socket_->GetError()); |
335 if (result == net::ERR_IO_PENDING) { | 307 if (result == net::ERR_IO_PENDING) { |
336 write_callback_ = callback; | 308 write_callback_ = callback; |
337 write_buffer_ = buf; | 309 write_buffer_ = buf; |
338 write_buffer_len_ = buf_len; | 310 write_buffer_len_ = buf_len; |
339 } | 311 } |
340 } | 312 } |
341 if (result != net::ERR_IO_PENDING) | 313 if (result != net::ERR_IO_PENDING) |
342 was_used_to_convey_data_ = true; | 314 was_used_to_convey_data_ = true; |
343 return result; | 315 return result; |
344 } | 316 } |
345 | 317 |
346 bool TransportSocket::SetReceiveBufferSize(int32 size) { | 318 bool TransportSocket::SetReceiveBufferSize(int32 size) { |
347 // Not implemented. | 319 // Not implemented. |
348 return false; | 320 return false; |
349 } | 321 } |
350 | 322 |
351 bool TransportSocket::SetSendBufferSize(int32 size) { | 323 bool TransportSocket::SetSendBufferSize(int32 size) { |
352 // Not implemented. | 324 // Not implemented. |
353 return false; | 325 return false; |
354 } | 326 } |
355 | 327 |
356 void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { | 328 void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { |
357 if (old_read_callback_ || !read_callback_.is_null()) { | 329 if (!read_callback_.is_null()) { |
358 DCHECK(read_buffer_.get()); | 330 DCHECK(read_buffer_.get()); |
359 net::OldCompletionCallback* old_callback = old_read_callback_; | |
360 net::CompletionCallback callback = read_callback_; | 331 net::CompletionCallback callback = read_callback_; |
361 scoped_refptr<net::IOBuffer> buffer = read_buffer_; | 332 scoped_refptr<net::IOBuffer> buffer = read_buffer_; |
362 int buffer_len = read_buffer_len_; | 333 int buffer_len = read_buffer_len_; |
363 | 334 |
364 old_read_callback_ = NULL; | |
365 read_callback_.Reset(); | 335 read_callback_.Reset(); |
366 read_buffer_ = NULL; | 336 read_buffer_ = NULL; |
367 read_buffer_len_ = 0; | 337 read_buffer_len_ = 0; |
368 | 338 |
369 int result = socket_->Recv(buffer->data(), buffer_len); | 339 int result = socket_->Recv(buffer->data(), buffer_len); |
370 if (result < 0) { | 340 if (result < 0) { |
371 result = net::MapSystemError(socket_->GetError()); | 341 result = net::MapSystemError(socket_->GetError()); |
372 if (result == net::ERR_IO_PENDING) { | 342 if (result == net::ERR_IO_PENDING) { |
373 old_read_callback_ = old_callback; | |
374 read_callback_ = callback; | 343 read_callback_ = callback; |
375 read_buffer_ = buffer; | 344 read_buffer_ = buffer; |
376 read_buffer_len_ = buffer_len; | 345 read_buffer_len_ = buffer_len; |
377 return; | 346 return; |
378 } | 347 } |
379 } | 348 } |
380 was_used_to_convey_data_ = true; | 349 was_used_to_convey_data_ = true; |
381 if (old_callback) | 350 callback.Run(result); |
382 old_callback->RunWithParams(Tuple1<int>(result)); | |
383 else | |
384 callback.Run(result); | |
385 } | 351 } |
386 } | 352 } |
387 | 353 |
388 void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { | 354 void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { |
389 if (write_callback_) { | 355 if (!write_callback_.is_null()) { |
390 DCHECK(write_buffer_.get()); | 356 DCHECK(write_buffer_.get()); |
391 net::OldCompletionCallback* callback = write_callback_; | 357 net::CompletionCallback callback = write_callback_; |
392 scoped_refptr<net::IOBuffer> buffer = write_buffer_; | 358 scoped_refptr<net::IOBuffer> buffer = write_buffer_; |
393 int buffer_len = write_buffer_len_; | 359 int buffer_len = write_buffer_len_; |
394 | 360 |
395 write_callback_ = NULL; | 361 write_callback_.Reset(); |
396 write_buffer_ = NULL; | 362 write_buffer_ = NULL; |
397 write_buffer_len_ = 0; | 363 write_buffer_len_ = 0; |
398 | 364 |
399 int result = socket_->Send(buffer->data(), buffer_len); | 365 int result = socket_->Send(buffer->data(), buffer_len); |
400 if (result < 0) { | 366 if (result < 0) { |
401 result = net::MapSystemError(socket_->GetError()); | 367 result = net::MapSystemError(socket_->GetError()); |
402 if (result == net::ERR_IO_PENDING) { | 368 if (result == net::ERR_IO_PENDING) { |
403 write_callback_ = callback; | 369 write_callback_ = callback; |
404 write_buffer_ = buffer; | 370 write_buffer_ = buffer; |
405 write_buffer_len_ = buffer_len; | 371 write_buffer_len_ = buffer_len; |
406 return; | 372 return; |
407 } | 373 } |
408 } | 374 } |
409 was_used_to_convey_data_ = true; | 375 was_used_to_convey_data_ = true; |
410 callback->RunWithParams(Tuple1<int>(result)); | 376 callback.Run(result); |
411 } | 377 } |
412 } | 378 } |
413 | 379 |
414 } // namespace remoting | 380 } // namespace remoting |
OLD | NEW |