| OLD | NEW |
| 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "remoting/jingle_glue/ssl_socket_adapter.h" | 5 #include "remoting/jingle_glue/ssl_socket_adapter.h" |
| 6 | 6 |
| 7 #include "base/base64.h" | 7 #include "base/base64.h" |
| 8 #include "base/compiler_specific.h" | 8 #include "base/compiler_specific.h" |
| 9 #include "base/message_loop.h" | 9 #include "base/message_loop.h" |
| 10 #include "jingle/glue/utils.h" | 10 #include "jingle/glue/utils.h" |
| 11 #include "net/base/address_list.h" | 11 #include "net/base/address_list.h" |
| 12 #include "net/base/cert_verifier.h" | 12 #include "net/base/cert_verifier.h" |
| 13 #include "net/base/host_port_pair.h" | 13 #include "net/base/host_port_pair.h" |
| 14 #include "net/base/net_errors.h" | 14 #include "net/base/net_errors.h" |
| 15 #include "net/base/ssl_config_service.h" | 15 #include "net/base/ssl_config_service.h" |
| 16 #include "net/base/sys_addrinfo.h" | 16 #include "net/base/sys_addrinfo.h" |
| 17 #include "net/socket/client_socket_factory.h" | 17 #include "net/socket/client_socket_factory.h" |
| 18 #include "net/url_request/url_request_context.h" | 18 #include "net/url_request/url_request_context.h" |
| 19 | 19 |
| 20 namespace remoting { | 20 namespace remoting { |
| 21 | 21 |
| 22 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { | 22 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { |
| 23 return new SSLSocketAdapter(socket); | 23 return new SSLSocketAdapter(socket); |
| 24 } | 24 } |
| 25 | 25 |
| 26 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) | 26 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) |
| 27 : SSLAdapter(socket), | 27 : SSLAdapter(socket), |
| 28 ignore_bad_cert_(false), | 28 ignore_bad_cert_(false), |
| 29 cert_verifier_(new net::CertVerifier()), | 29 cert_verifier_(new net::CertVerifier()), |
| 30 ALLOW_THIS_IN_INITIALIZER_LIST( | |
| 31 connected_callback_(this, &SSLSocketAdapter::OnConnected)), | |
| 32 ALLOW_THIS_IN_INITIALIZER_LIST( | |
| 33 read_callback_(this, &SSLSocketAdapter::OnRead)), | |
| 34 ALLOW_THIS_IN_INITIALIZER_LIST( | |
| 35 write_callback_(this, &SSLSocketAdapter::OnWrite)), | |
| 36 ssl_state_(SSLSTATE_NONE), | 30 ssl_state_(SSLSTATE_NONE), |
| 37 read_state_(IOSTATE_NONE), | 31 read_state_(IOSTATE_NONE), |
| 38 write_state_(IOSTATE_NONE) { | 32 write_state_(IOSTATE_NONE) { |
| 39 transport_socket_ = new TransportSocket(socket, this); | 33 transport_socket_ = new TransportSocket(socket, this); |
| 40 } | 34 } |
| 41 | 35 |
| 42 SSLSocketAdapter::~SSLSocketAdapter() { | 36 SSLSocketAdapter::~SSLSocketAdapter() { |
| 43 } | 37 } |
| 44 | 38 |
| 45 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { | 39 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { |
| (...skipping 24 matching lines...) Expand all Loading... |
| 70 net::SSLConfig ssl_config; | 64 net::SSLConfig ssl_config; |
| 71 net::SSLClientSocketContext context; | 65 net::SSLClientSocketContext context; |
| 72 context.cert_verifier = cert_verifier_.get(); | 66 context.cert_verifier = cert_verifier_.get(); |
| 73 | 67 |
| 74 transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); | 68 transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); |
| 75 ssl_socket_.reset( | 69 ssl_socket_.reset( |
| 76 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( | 70 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( |
| 77 transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, | 71 transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, |
| 78 NULL /* ssl_host_info */, context)); | 72 NULL /* ssl_host_info */, context)); |
| 79 | 73 |
| 80 int result = ssl_socket_->Connect(&connected_callback_); | 74 int result = ssl_socket_->Connect( |
| 75 base::Bind(&SSLSocketAdapter::OnConnected, base::Unretained(this))); |
| 81 | 76 |
| 82 if (result == net::ERR_IO_PENDING || result == net::OK) { | 77 if (result == net::ERR_IO_PENDING || result == net::OK) { |
| 83 return 0; | 78 return 0; |
| 84 } else { | 79 } else { |
| 85 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); | 80 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); |
| 86 return result; | 81 return result; |
| 87 } | 82 } |
| 88 } | 83 } |
| 89 | 84 |
| 90 int SSLSocketAdapter::Send(const void* buf, size_t len) { | 85 int SSLSocketAdapter::Send(const void* buf, size_t len) { |
| 91 if (ssl_state_ != SSLSTATE_CONNECTED) { | 86 if (ssl_state_ != SSLSTATE_CONNECTED) { |
| 92 return AsyncSocketAdapter::Send(buf, len); | 87 return AsyncSocketAdapter::Send(buf, len); |
| 93 } else { | 88 } else { |
| 94 scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len)); | 89 scoped_refptr<net::IOBuffer> transport_buf(new net::IOBuffer(len)); |
| 95 memcpy(transport_buf->data(), buf, len); | 90 memcpy(transport_buf->data(), buf, len); |
| 96 | 91 |
| 97 int result = ssl_socket_->Write(transport_buf, len, NULL); | 92 int result = ssl_socket_->Write(transport_buf, len, |
| 93 net::CompletionCallback()); |
| 98 if (result == net::ERR_IO_PENDING) { | 94 if (result == net::ERR_IO_PENDING) { |
| 99 SetError(EWOULDBLOCK); | 95 SetError(EWOULDBLOCK); |
| 100 } | 96 } |
| 101 transport_buf = NULL; | 97 transport_buf = NULL; |
| 102 return result; | 98 return result; |
| 103 } | 99 } |
| 104 } | 100 } |
| 105 | 101 |
| 106 int SSLSocketAdapter::Recv(void* buf, size_t len) { | 102 int SSLSocketAdapter::Recv(void* buf, size_t len) { |
| 107 switch (ssl_state_) { | 103 switch (ssl_state_) { |
| 108 case SSLSTATE_NONE: | 104 case SSLSTATE_NONE: |
| 109 return AsyncSocketAdapter::Recv(buf, len); | 105 return AsyncSocketAdapter::Recv(buf, len); |
| 110 | 106 |
| 111 case SSLSTATE_WAIT: | 107 case SSLSTATE_WAIT: |
| 112 SetError(EWOULDBLOCK); | 108 SetError(EWOULDBLOCK); |
| 113 return -1; | 109 return -1; |
| 114 | 110 |
| 115 case SSLSTATE_CONNECTED: | 111 case SSLSTATE_CONNECTED: |
| 116 switch (read_state_) { | 112 switch (read_state_) { |
| 117 case IOSTATE_NONE: { | 113 case IOSTATE_NONE: { |
| 118 transport_buf_ = new net::IOBuffer(len); | 114 transport_buf_ = new net::IOBuffer(len); |
| 119 int result = ssl_socket_->Read(transport_buf_, len, &read_callback_); | 115 int result = ssl_socket_->Read( |
| 116 transport_buf_, len, |
| 117 base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); |
| 120 if (result >= 0) { | 118 if (result >= 0) { |
| 121 memcpy(buf, transport_buf_->data(), len); | 119 memcpy(buf, transport_buf_->data(), len); |
| 122 } | 120 } |
| 123 | 121 |
| 124 if (result == net::ERR_IO_PENDING) { | 122 if (result == net::ERR_IO_PENDING) { |
| 125 read_state_ = IOSTATE_PENDING; | 123 read_state_ = IOSTATE_PENDING; |
| 126 SetError(EWOULDBLOCK); | 124 SetError(EWOULDBLOCK); |
| 127 } else { | 125 } else { |
| 128 if (result < 0) { | 126 if (result < 0) { |
| 129 SetError(result); | 127 SetError(result); |
| (...skipping 50 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 180 int result = BeginSSL(); | 178 int result = BeginSSL(); |
| 181 if (0 != result) { | 179 if (0 != result) { |
| 182 // TODO(zork): Handle this case gracefully. | 180 // TODO(zork): Handle this case gracefully. |
| 183 LOG(WARNING) << "BeginSSL() failed with " << result; | 181 LOG(WARNING) << "BeginSSL() failed with " << result; |
| 184 } | 182 } |
| 185 } | 183 } |
| 186 } | 184 } |
| 187 | 185 |
| 188 TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, | 186 TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, |
| 189 SSLSocketAdapter *ssl_adapter) | 187 SSLSocketAdapter *ssl_adapter) |
| 190 : old_read_callback_(NULL), | 188 : read_buffer_len_(0), |
| 191 write_callback_(NULL), | |
| 192 read_buffer_len_(0), | |
| 193 write_buffer_len_(0), | 189 write_buffer_len_(0), |
| 194 socket_(socket), | 190 socket_(socket), |
| 195 was_used_to_convey_data_(false) { | 191 was_used_to_convey_data_(false) { |
| 196 socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); | 192 socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); |
| 197 socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); | 193 socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); |
| 198 } | 194 } |
| 199 | 195 |
| 200 TransportSocket::~TransportSocket() { | 196 TransportSocket::~TransportSocket() { |
| 201 } | 197 } |
| 202 | 198 |
| 203 int TransportSocket::Connect(net::OldCompletionCallback* callback) { | |
| 204 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter | |
| 205 // calls Connect() on socket_ directly. | |
| 206 NOTREACHED(); | |
| 207 return false; | |
| 208 } | |
| 209 int TransportSocket::Connect(const net::CompletionCallback& callback) { | 199 int TransportSocket::Connect(const net::CompletionCallback& callback) { |
| 210 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter | 200 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter |
| 211 // calls Connect() on socket_ directly. | 201 // calls Connect() on socket_ directly. |
| 212 NOTREACHED(); | 202 NOTREACHED(); |
| 213 return false; | 203 return false; |
| 214 } | 204 } |
| 215 | 205 |
| 216 void TransportSocket::Disconnect() { | 206 void TransportSocket::Disconnect() { |
| 217 socket_->Close(); | 207 socket_->Close(); |
| 218 } | 208 } |
| (...skipping 62 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 281 NOTREACHED(); | 271 NOTREACHED(); |
| 282 return -1; | 272 return -1; |
| 283 } | 273 } |
| 284 | 274 |
| 285 base::TimeDelta TransportSocket::GetConnectTimeMicros() const { | 275 base::TimeDelta TransportSocket::GetConnectTimeMicros() const { |
| 286 NOTREACHED(); | 276 NOTREACHED(); |
| 287 return base::TimeDelta::FromMicroseconds(-1); | 277 return base::TimeDelta::FromMicroseconds(-1); |
| 288 } | 278 } |
| 289 | 279 |
| 290 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, | 280 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, |
| 291 net::OldCompletionCallback* callback) { | 281 const net::CompletionCallback& callback) { |
| 292 DCHECK(buf); | 282 DCHECK(buf); |
| 293 DCHECK(!old_read_callback_ && read_callback_.is_null()); | 283 DCHECK(read_callback_.is_null()); |
| 294 DCHECK(!read_buffer_.get()); | 284 DCHECK(!read_buffer_.get()); |
| 295 int result = socket_->Recv(buf->data(), buf_len); | 285 int result = socket_->Recv(buf->data(), buf_len); |
| 296 if (result < 0) { | 286 if (result < 0) { |
| 297 result = net::MapSystemError(socket_->GetError()); | |
| 298 if (result == net::ERR_IO_PENDING) { | |
| 299 old_read_callback_ = callback; | |
| 300 read_buffer_ = buf; | |
| 301 read_buffer_len_ = buf_len; | |
| 302 } | |
| 303 } | |
| 304 if (result != net::ERR_IO_PENDING) | |
| 305 was_used_to_convey_data_ = true; | |
| 306 return result; | |
| 307 } | |
| 308 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, | |
| 309 const net::CompletionCallback& callback) { | |
| 310 DCHECK(buf); | |
| 311 DCHECK(!old_read_callback_ && read_callback_.is_null()); | |
| 312 DCHECK(!read_buffer_.get()); | |
| 313 int result = socket_->Recv(buf->data(), buf_len); | |
| 314 if (result < 0) { | |
| 315 result = net::MapSystemError(socket_->GetError()); | 287 result = net::MapSystemError(socket_->GetError()); |
| 316 if (result == net::ERR_IO_PENDING) { | 288 if (result == net::ERR_IO_PENDING) { |
| 317 read_callback_ = callback; | 289 read_callback_ = callback; |
| 318 read_buffer_ = buf; | 290 read_buffer_ = buf; |
| 319 read_buffer_len_ = buf_len; | 291 read_buffer_len_ = buf_len; |
| 320 } | 292 } |
| 321 } | 293 } |
| 322 if (result != net::ERR_IO_PENDING) | 294 if (result != net::ERR_IO_PENDING) |
| 323 was_used_to_convey_data_ = true; | 295 was_used_to_convey_data_ = true; |
| 324 return result; | 296 return result; |
| 325 } | 297 } |
| 326 | 298 |
| 327 int TransportSocket::Write(net::IOBuffer* buf, int buf_len, | 299 int TransportSocket::Write(net::IOBuffer* buf, int buf_len, |
| 328 net::OldCompletionCallback* callback) { | 300 const net::CompletionCallback& callback) { |
| 329 DCHECK(buf); | 301 DCHECK(buf); |
| 330 DCHECK(!write_callback_); | 302 DCHECK(write_callback_.is_null()); |
| 331 DCHECK(!write_buffer_.get()); | 303 DCHECK(!write_buffer_.get()); |
| 332 int result = socket_->Send(buf->data(), buf_len); | 304 int result = socket_->Send(buf->data(), buf_len); |
| 333 if (result < 0) { | 305 if (result < 0) { |
| 334 result = net::MapSystemError(socket_->GetError()); | 306 result = net::MapSystemError(socket_->GetError()); |
| 335 if (result == net::ERR_IO_PENDING) { | 307 if (result == net::ERR_IO_PENDING) { |
| 336 write_callback_ = callback; | 308 write_callback_ = callback; |
| 337 write_buffer_ = buf; | 309 write_buffer_ = buf; |
| 338 write_buffer_len_ = buf_len; | 310 write_buffer_len_ = buf_len; |
| 339 } | 311 } |
| 340 } | 312 } |
| 341 if (result != net::ERR_IO_PENDING) | 313 if (result != net::ERR_IO_PENDING) |
| 342 was_used_to_convey_data_ = true; | 314 was_used_to_convey_data_ = true; |
| 343 return result; | 315 return result; |
| 344 } | 316 } |
| 345 | 317 |
| 346 bool TransportSocket::SetReceiveBufferSize(int32 size) { | 318 bool TransportSocket::SetReceiveBufferSize(int32 size) { |
| 347 // Not implemented. | 319 // Not implemented. |
| 348 return false; | 320 return false; |
| 349 } | 321 } |
| 350 | 322 |
| 351 bool TransportSocket::SetSendBufferSize(int32 size) { | 323 bool TransportSocket::SetSendBufferSize(int32 size) { |
| 352 // Not implemented. | 324 // Not implemented. |
| 353 return false; | 325 return false; |
| 354 } | 326 } |
| 355 | 327 |
| 356 void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { | 328 void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { |
| 357 if (old_read_callback_ || !read_callback_.is_null()) { | 329 if (!read_callback_.is_null()) { |
| 358 DCHECK(read_buffer_.get()); | 330 DCHECK(read_buffer_.get()); |
| 359 net::OldCompletionCallback* old_callback = old_read_callback_; | |
| 360 net::CompletionCallback callback = read_callback_; | 331 net::CompletionCallback callback = read_callback_; |
| 361 scoped_refptr<net::IOBuffer> buffer = read_buffer_; | 332 scoped_refptr<net::IOBuffer> buffer = read_buffer_; |
| 362 int buffer_len = read_buffer_len_; | 333 int buffer_len = read_buffer_len_; |
| 363 | 334 |
| 364 old_read_callback_ = NULL; | |
| 365 read_callback_.Reset(); | 335 read_callback_.Reset(); |
| 366 read_buffer_ = NULL; | 336 read_buffer_ = NULL; |
| 367 read_buffer_len_ = 0; | 337 read_buffer_len_ = 0; |
| 368 | 338 |
| 369 int result = socket_->Recv(buffer->data(), buffer_len); | 339 int result = socket_->Recv(buffer->data(), buffer_len); |
| 370 if (result < 0) { | 340 if (result < 0) { |
| 371 result = net::MapSystemError(socket_->GetError()); | 341 result = net::MapSystemError(socket_->GetError()); |
| 372 if (result == net::ERR_IO_PENDING) { | 342 if (result == net::ERR_IO_PENDING) { |
| 373 old_read_callback_ = old_callback; | |
| 374 read_callback_ = callback; | 343 read_callback_ = callback; |
| 375 read_buffer_ = buffer; | 344 read_buffer_ = buffer; |
| 376 read_buffer_len_ = buffer_len; | 345 read_buffer_len_ = buffer_len; |
| 377 return; | 346 return; |
| 378 } | 347 } |
| 379 } | 348 } |
| 380 was_used_to_convey_data_ = true; | 349 was_used_to_convey_data_ = true; |
| 381 if (old_callback) | 350 callback.Run(result); |
| 382 old_callback->RunWithParams(Tuple1<int>(result)); | |
| 383 else | |
| 384 callback.Run(result); | |
| 385 } | 351 } |
| 386 } | 352 } |
| 387 | 353 |
| 388 void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { | 354 void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { |
| 389 if (write_callback_) { | 355 if (!write_callback_.is_null()) { |
| 390 DCHECK(write_buffer_.get()); | 356 DCHECK(write_buffer_.get()); |
| 391 net::OldCompletionCallback* callback = write_callback_; | 357 net::CompletionCallback callback = write_callback_; |
| 392 scoped_refptr<net::IOBuffer> buffer = write_buffer_; | 358 scoped_refptr<net::IOBuffer> buffer = write_buffer_; |
| 393 int buffer_len = write_buffer_len_; | 359 int buffer_len = write_buffer_len_; |
| 394 | 360 |
| 395 write_callback_ = NULL; | 361 write_callback_.Reset(); |
| 396 write_buffer_ = NULL; | 362 write_buffer_ = NULL; |
| 397 write_buffer_len_ = 0; | 363 write_buffer_len_ = 0; |
| 398 | 364 |
| 399 int result = socket_->Send(buffer->data(), buffer_len); | 365 int result = socket_->Send(buffer->data(), buffer_len); |
| 400 if (result < 0) { | 366 if (result < 0) { |
| 401 result = net::MapSystemError(socket_->GetError()); | 367 result = net::MapSystemError(socket_->GetError()); |
| 402 if (result == net::ERR_IO_PENDING) { | 368 if (result == net::ERR_IO_PENDING) { |
| 403 write_callback_ = callback; | 369 write_callback_ = callback; |
| 404 write_buffer_ = buffer; | 370 write_buffer_ = buffer; |
| 405 write_buffer_len_ = buffer_len; | 371 write_buffer_len_ = buffer_len; |
| 406 return; | 372 return; |
| 407 } | 373 } |
| 408 } | 374 } |
| 409 was_used_to_convey_data_ = true; | 375 was_used_to_convey_data_ = true; |
| 410 callback->RunWithParams(Tuple1<int>(result)); | 376 callback.Run(result); |
| 411 } | 377 } |
| 412 } | 378 } |
| 413 | 379 |
| 414 } // namespace remoting | 380 } // namespace remoting |
| OLD | NEW |