| OLD | NEW |
| (Empty) |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "remoting/jingle_glue/ssl_socket_adapter.h" | |
| 6 | |
| 7 #include "base/base64.h" | |
| 8 #include "base/compiler_specific.h" | |
| 9 #include "base/message_loop.h" | |
| 10 #include "jingle/glue/utils.h" | |
| 11 #include "net/base/address_list.h" | |
| 12 #include "net/base/cert_verifier.h" | |
| 13 #include "net/base/host_port_pair.h" | |
| 14 #include "net/base/net_errors.h" | |
| 15 #include "net/base/ssl_config_service.h" | |
| 16 #include "net/base/transport_security_state.h" | |
| 17 #include "net/socket/client_socket_factory.h" | |
| 18 #include "net/url_request/url_request_context.h" | |
| 19 | |
| 20 namespace remoting { | |
| 21 | |
| 22 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { | |
| 23 return new SSLSocketAdapter(socket); | |
| 24 } | |
| 25 | |
| 26 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) | |
| 27 : SSLAdapter(socket), | |
| 28 ignore_bad_cert_(false), | |
| 29 cert_verifier_(net::CertVerifier::CreateDefault()), | |
| 30 transport_security_state_(new net::TransportSecurityState()), | |
| 31 ssl_state_(SSLSTATE_NONE), | |
| 32 read_pending_(false), | |
| 33 write_pending_(false) { | |
| 34 transport_socket_ = new TransportSocket(socket, this); | |
| 35 } | |
| 36 | |
| 37 SSLSocketAdapter::~SSLSocketAdapter() { | |
| 38 } | |
| 39 | |
| 40 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { | |
| 41 DCHECK(!restartable); | |
| 42 hostname_ = hostname; | |
| 43 | |
| 44 if (socket_->GetState() != Socket::CS_CONNECTED) { | |
| 45 ssl_state_ = SSLSTATE_WAIT; | |
| 46 return 0; | |
| 47 } else { | |
| 48 return BeginSSL(); | |
| 49 } | |
| 50 } | |
| 51 | |
| 52 int SSLSocketAdapter::BeginSSL() { | |
| 53 if (!MessageLoop::current()) { | |
| 54 // Certificate verification is done via the Chrome message loop. | |
| 55 // Without this check, if we don't have a chrome message loop the | |
| 56 // SSL connection just hangs silently. | |
| 57 LOG(DFATAL) << "Chrome message loop (needed by SSL certificate " | |
| 58 << "verification) does not exist"; | |
| 59 return net::ERR_UNEXPECTED; | |
| 60 } | |
| 61 | |
| 62 // SSLConfigService is not thread-safe, and the default values for SSLConfig | |
| 63 // are correct for us, so we don't use the config service to initialize this | |
| 64 // object. | |
| 65 net::SSLConfig ssl_config; | |
| 66 net::SSLClientSocketContext context( | |
| 67 cert_verifier_.get(), NULL, transport_security_state_.get(), ""); | |
| 68 | |
| 69 transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); | |
| 70 ssl_socket_.reset( | |
| 71 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( | |
| 72 transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, | |
| 73 context)); | |
| 74 | |
| 75 int result = ssl_socket_->Connect( | |
| 76 base::Bind(&SSLSocketAdapter::OnConnected, base::Unretained(this))); | |
| 77 | |
| 78 if (result == net::ERR_IO_PENDING || result == net::OK) { | |
| 79 return 0; | |
| 80 } else { | |
| 81 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); | |
| 82 return result; | |
| 83 } | |
| 84 } | |
| 85 | |
| 86 int SSLSocketAdapter::Send(const void* buf, size_t len) { | |
| 87 if (ssl_state_ == SSLSTATE_ERROR) { | |
| 88 SetError(EINVAL); | |
| 89 return -1; | |
| 90 } | |
| 91 | |
| 92 if (ssl_state_ == SSLSTATE_NONE) { | |
| 93 // Propagate the call to underlying socket if SSL is not connected | |
| 94 // yet (connection is not encrypted until StartSSL() is called). | |
| 95 return AsyncSocketAdapter::Send(buf, len); | |
| 96 } | |
| 97 | |
| 98 if (write_pending_) { | |
| 99 SetError(EWOULDBLOCK); | |
| 100 return -1; | |
| 101 } | |
| 102 | |
| 103 write_buffer_ = new net::DrainableIOBuffer(new net::IOBuffer(len), len); | |
| 104 memcpy(write_buffer_->data(), buf, len); | |
| 105 | |
| 106 DoWrite(); | |
| 107 | |
| 108 return len; | |
| 109 } | |
| 110 | |
| 111 int SSLSocketAdapter::Recv(void* buf, size_t len) { | |
| 112 switch (ssl_state_) { | |
| 113 case SSLSTATE_NONE: { | |
| 114 return AsyncSocketAdapter::Recv(buf, len); | |
| 115 } | |
| 116 | |
| 117 case SSLSTATE_WAIT: { | |
| 118 SetError(EWOULDBLOCK); | |
| 119 return -1; | |
| 120 } | |
| 121 | |
| 122 case SSLSTATE_CONNECTED: { | |
| 123 if (read_pending_) { | |
| 124 SetError(EWOULDBLOCK); | |
| 125 return -1; | |
| 126 } | |
| 127 | |
| 128 int bytes_read = 0; | |
| 129 | |
| 130 // Process any data we have left from the previous read. | |
| 131 if (read_buffer_) { | |
| 132 int size = std::min(read_buffer_->RemainingCapacity(), | |
| 133 static_cast<int>(len)); | |
| 134 memcpy(buf, read_buffer_->data(), size); | |
| 135 read_buffer_->set_offset(read_buffer_->offset() + size); | |
| 136 if (!read_buffer_->RemainingCapacity()) | |
| 137 read_buffer_ = NULL; | |
| 138 | |
| 139 if (size == static_cast<int>(len)) | |
| 140 return size; | |
| 141 | |
| 142 // If we didn't fill the caller's buffer then dispatch a new | |
| 143 // Read() in case there's more data ready. | |
| 144 buf = reinterpret_cast<char*>(buf) + size; | |
| 145 len -= size; | |
| 146 bytes_read = size; | |
| 147 DCHECK(!read_buffer_); | |
| 148 } | |
| 149 | |
| 150 // Dispatch a Read() request to the SSL layer. | |
| 151 read_buffer_ = new net::GrowableIOBuffer(); | |
| 152 read_buffer_->SetCapacity(len); | |
| 153 int result = ssl_socket_->Read( | |
| 154 read_buffer_, len, | |
| 155 base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); | |
| 156 if (result >= 0) | |
| 157 memcpy(buf, read_buffer_->data(), len); | |
| 158 | |
| 159 if (result == net::ERR_IO_PENDING) { | |
| 160 read_pending_ = true; | |
| 161 if (bytes_read) { | |
| 162 return bytes_read; | |
| 163 } else { | |
| 164 SetError(EWOULDBLOCK); | |
| 165 return -1; | |
| 166 } | |
| 167 } | |
| 168 | |
| 169 if (result < 0) { | |
| 170 SetError(EINVAL); | |
| 171 ssl_state_ = SSLSTATE_ERROR; | |
| 172 LOG(ERROR) << "Error reading from SSL socket " << result; | |
| 173 return -1; | |
| 174 } | |
| 175 read_buffer_ = NULL; | |
| 176 return result + bytes_read; | |
| 177 } | |
| 178 | |
| 179 case SSLSTATE_ERROR: { | |
| 180 SetError(EINVAL); | |
| 181 return -1; | |
| 182 } | |
| 183 } | |
| 184 | |
| 185 NOTREACHED(); | |
| 186 return -1; | |
| 187 } | |
| 188 | |
| 189 void SSLSocketAdapter::OnConnected(int result) { | |
| 190 if (result == net::OK) { | |
| 191 ssl_state_ = SSLSTATE_CONNECTED; | |
| 192 OnConnectEvent(this); | |
| 193 } else { | |
| 194 LOG(WARNING) << "OnConnected failed with error " << result; | |
| 195 } | |
| 196 } | |
| 197 | |
| 198 void SSLSocketAdapter::OnRead(int result) { | |
| 199 DCHECK(read_pending_); | |
| 200 read_pending_ = false; | |
| 201 if (result > 0) { | |
| 202 DCHECK_GE(read_buffer_->capacity(), result); | |
| 203 read_buffer_->SetCapacity(result); | |
| 204 } else { | |
| 205 if (result < 0) | |
| 206 ssl_state_ = SSLSTATE_ERROR; | |
| 207 } | |
| 208 AsyncSocketAdapter::OnReadEvent(this); | |
| 209 } | |
| 210 | |
| 211 void SSLSocketAdapter::OnWritten(int result) { | |
| 212 DCHECK(write_pending_); | |
| 213 write_pending_ = false; | |
| 214 if (result >= 0) { | |
| 215 write_buffer_->DidConsume(result); | |
| 216 if (!write_buffer_->BytesRemaining()) { | |
| 217 write_buffer_ = NULL; | |
| 218 } else { | |
| 219 DoWrite(); | |
| 220 } | |
| 221 } else { | |
| 222 ssl_state_ = SSLSTATE_ERROR; | |
| 223 } | |
| 224 AsyncSocketAdapter::OnWriteEvent(this); | |
| 225 } | |
| 226 | |
| 227 void SSLSocketAdapter::DoWrite() { | |
| 228 DCHECK_GT(write_buffer_->BytesRemaining(), 0); | |
| 229 DCHECK(!write_pending_); | |
| 230 | |
| 231 while (true) { | |
| 232 int result = ssl_socket_->Write( | |
| 233 write_buffer_, write_buffer_->BytesRemaining(), | |
| 234 base::Bind(&SSLSocketAdapter::OnWritten, base::Unretained(this))); | |
| 235 | |
| 236 if (result > 0) { | |
| 237 write_buffer_->DidConsume(result); | |
| 238 if (!write_buffer_->BytesRemaining()) { | |
| 239 write_buffer_ = NULL; | |
| 240 return; | |
| 241 } | |
| 242 continue; | |
| 243 } | |
| 244 | |
| 245 if (result == net::ERR_IO_PENDING) { | |
| 246 write_pending_ = true; | |
| 247 } else { | |
| 248 SetError(EINVAL); | |
| 249 ssl_state_ = SSLSTATE_ERROR; | |
| 250 } | |
| 251 return; | |
| 252 } | |
| 253 } | |
| 254 | |
| 255 void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { | |
| 256 if (ssl_state_ != SSLSTATE_WAIT) { | |
| 257 AsyncSocketAdapter::OnConnectEvent(socket); | |
| 258 } else { | |
| 259 ssl_state_ = SSLSTATE_NONE; | |
| 260 int result = BeginSSL(); | |
| 261 if (0 != result) { | |
| 262 // TODO(zork): Handle this case gracefully. | |
| 263 LOG(WARNING) << "BeginSSL() failed with " << result; | |
| 264 } | |
| 265 } | |
| 266 } | |
| 267 | |
| 268 TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, | |
| 269 SSLSocketAdapter *ssl_adapter) | |
| 270 : read_buffer_len_(0), | |
| 271 write_buffer_len_(0), | |
| 272 socket_(socket), | |
| 273 was_used_to_convey_data_(false) { | |
| 274 socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); | |
| 275 socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); | |
| 276 } | |
| 277 | |
| 278 TransportSocket::~TransportSocket() { | |
| 279 } | |
| 280 | |
| 281 int TransportSocket::Connect(const net::CompletionCallback& callback) { | |
| 282 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter | |
| 283 // calls Connect() on socket_ directly. | |
| 284 NOTREACHED(); | |
| 285 return false; | |
| 286 } | |
| 287 | |
| 288 void TransportSocket::Disconnect() { | |
| 289 socket_->Close(); | |
| 290 } | |
| 291 | |
| 292 bool TransportSocket::IsConnected() const { | |
| 293 return (socket_->GetState() == talk_base::Socket::CS_CONNECTED); | |
| 294 } | |
| 295 | |
| 296 bool TransportSocket::IsConnectedAndIdle() const { | |
| 297 // Not implemented. | |
| 298 NOTREACHED(); | |
| 299 return false; | |
| 300 } | |
| 301 | |
| 302 int TransportSocket::GetPeerAddress(net::IPEndPoint* address) const { | |
| 303 talk_base::SocketAddress socket_address = socket_->GetRemoteAddress(); | |
| 304 if (jingle_glue::SocketAddressToIPEndPoint(socket_address, address)) { | |
| 305 return net::OK; | |
| 306 } else { | |
| 307 return net::ERR_FAILED; | |
| 308 } | |
| 309 } | |
| 310 | |
| 311 int TransportSocket::GetLocalAddress(net::IPEndPoint* address) const { | |
| 312 talk_base::SocketAddress socket_address = socket_->GetLocalAddress(); | |
| 313 if (jingle_glue::SocketAddressToIPEndPoint(socket_address, address)) { | |
| 314 return net::OK; | |
| 315 } else { | |
| 316 return net::ERR_FAILED; | |
| 317 } | |
| 318 } | |
| 319 | |
| 320 const net::BoundNetLog& TransportSocket::NetLog() const { | |
| 321 return net_log_; | |
| 322 } | |
| 323 | |
| 324 void TransportSocket::SetSubresourceSpeculation() { | |
| 325 NOTREACHED(); | |
| 326 } | |
| 327 | |
| 328 void TransportSocket::SetOmniboxSpeculation() { | |
| 329 NOTREACHED(); | |
| 330 } | |
| 331 | |
| 332 bool TransportSocket::WasEverUsed() const { | |
| 333 // We don't use this in ClientSocketPools, so this should never be used. | |
| 334 NOTREACHED(); | |
| 335 return was_used_to_convey_data_; | |
| 336 } | |
| 337 | |
| 338 bool TransportSocket::UsingTCPFastOpen() const { | |
| 339 return false; | |
| 340 } | |
| 341 | |
| 342 int64 TransportSocket::NumBytesRead() const { | |
| 343 NOTREACHED(); | |
| 344 return -1; | |
| 345 } | |
| 346 | |
| 347 base::TimeDelta TransportSocket::GetConnectTimeMicros() const { | |
| 348 NOTREACHED(); | |
| 349 return base::TimeDelta::FromMicroseconds(-1); | |
| 350 } | |
| 351 | |
| 352 bool TransportSocket::WasNpnNegotiated() const { | |
| 353 NOTREACHED(); | |
| 354 return false; | |
| 355 } | |
| 356 | |
| 357 net::NextProto TransportSocket::GetNegotiatedProtocol() const { | |
| 358 NOTREACHED(); | |
| 359 return net::kProtoUnknown; | |
| 360 } | |
| 361 | |
| 362 bool TransportSocket::GetSSLInfo(net::SSLInfo* ssl_info) { | |
| 363 NOTREACHED(); | |
| 364 return false; | |
| 365 } | |
| 366 | |
| 367 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, | |
| 368 const net::CompletionCallback& callback) { | |
| 369 DCHECK(buf); | |
| 370 DCHECK(read_callback_.is_null()); | |
| 371 DCHECK(!read_buffer_.get()); | |
| 372 int result = socket_->Recv(buf->data(), buf_len); | |
| 373 if (result < 0) { | |
| 374 result = net::MapSystemError(socket_->GetError()); | |
| 375 if (result == net::ERR_IO_PENDING) { | |
| 376 read_callback_ = callback; | |
| 377 read_buffer_ = buf; | |
| 378 read_buffer_len_ = buf_len; | |
| 379 } | |
| 380 } | |
| 381 if (result != net::ERR_IO_PENDING) | |
| 382 was_used_to_convey_data_ = true; | |
| 383 return result; | |
| 384 } | |
| 385 | |
| 386 int TransportSocket::Write(net::IOBuffer* buf, int buf_len, | |
| 387 const net::CompletionCallback& callback) { | |
| 388 DCHECK(buf); | |
| 389 DCHECK(write_callback_.is_null()); | |
| 390 DCHECK(!write_buffer_.get()); | |
| 391 int result = socket_->Send(buf->data(), buf_len); | |
| 392 if (result < 0) { | |
| 393 result = net::MapSystemError(socket_->GetError()); | |
| 394 if (result == net::ERR_IO_PENDING) { | |
| 395 write_callback_ = callback; | |
| 396 write_buffer_ = buf; | |
| 397 write_buffer_len_ = buf_len; | |
| 398 } | |
| 399 } | |
| 400 if (result != net::ERR_IO_PENDING) | |
| 401 was_used_to_convey_data_ = true; | |
| 402 return result; | |
| 403 } | |
| 404 | |
| 405 bool TransportSocket::SetReceiveBufferSize(int32 size) { | |
| 406 // Not implemented. | |
| 407 return false; | |
| 408 } | |
| 409 | |
| 410 bool TransportSocket::SetSendBufferSize(int32 size) { | |
| 411 // Not implemented. | |
| 412 return false; | |
| 413 } | |
| 414 | |
| 415 void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { | |
| 416 if (!read_callback_.is_null()) { | |
| 417 DCHECK(read_buffer_.get()); | |
| 418 net::CompletionCallback callback = read_callback_; | |
| 419 scoped_refptr<net::IOBuffer> buffer = read_buffer_; | |
| 420 int buffer_len = read_buffer_len_; | |
| 421 | |
| 422 read_callback_.Reset(); | |
| 423 read_buffer_ = NULL; | |
| 424 read_buffer_len_ = 0; | |
| 425 | |
| 426 int result = socket_->Recv(buffer->data(), buffer_len); | |
| 427 if (result < 0) { | |
| 428 result = net::MapSystemError(socket_->GetError()); | |
| 429 if (result == net::ERR_IO_PENDING) { | |
| 430 read_callback_ = callback; | |
| 431 read_buffer_ = buffer; | |
| 432 read_buffer_len_ = buffer_len; | |
| 433 return; | |
| 434 } | |
| 435 } | |
| 436 was_used_to_convey_data_ = true; | |
| 437 callback.Run(result); | |
| 438 } | |
| 439 } | |
| 440 | |
| 441 void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { | |
| 442 if (!write_callback_.is_null()) { | |
| 443 DCHECK(write_buffer_.get()); | |
| 444 net::CompletionCallback callback = write_callback_; | |
| 445 scoped_refptr<net::IOBuffer> buffer = write_buffer_; | |
| 446 int buffer_len = write_buffer_len_; | |
| 447 | |
| 448 write_callback_.Reset(); | |
| 449 write_buffer_ = NULL; | |
| 450 write_buffer_len_ = 0; | |
| 451 | |
| 452 int result = socket_->Send(buffer->data(), buffer_len); | |
| 453 if (result < 0) { | |
| 454 result = net::MapSystemError(socket_->GetError()); | |
| 455 if (result == net::ERR_IO_PENDING) { | |
| 456 write_callback_ = callback; | |
| 457 write_buffer_ = buffer; | |
| 458 write_buffer_len_ = buffer_len; | |
| 459 return; | |
| 460 } | |
| 461 } | |
| 462 was_used_to_convey_data_ = true; | |
| 463 callback.Run(result); | |
| 464 } | |
| 465 } | |
| 466 | |
| 467 } // namespace remoting | |
| OLD | NEW |