OLD | NEW |
| (Empty) |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "remoting/jingle_glue/ssl_socket_adapter.h" | |
6 | |
7 #include "base/base64.h" | |
8 #include "base/compiler_specific.h" | |
9 #include "base/message_loop.h" | |
10 #include "jingle/glue/utils.h" | |
11 #include "net/base/address_list.h" | |
12 #include "net/base/cert_verifier.h" | |
13 #include "net/base/host_port_pair.h" | |
14 #include "net/base/net_errors.h" | |
15 #include "net/base/ssl_config_service.h" | |
16 #include "net/base/transport_security_state.h" | |
17 #include "net/socket/client_socket_factory.h" | |
18 #include "net/url_request/url_request_context.h" | |
19 | |
20 namespace remoting { | |
21 | |
22 SSLSocketAdapter* SSLSocketAdapter::Create(AsyncSocket* socket) { | |
23 return new SSLSocketAdapter(socket); | |
24 } | |
25 | |
26 SSLSocketAdapter::SSLSocketAdapter(AsyncSocket* socket) | |
27 : SSLAdapter(socket), | |
28 ignore_bad_cert_(false), | |
29 cert_verifier_(net::CertVerifier::CreateDefault()), | |
30 transport_security_state_(new net::TransportSecurityState()), | |
31 ssl_state_(SSLSTATE_NONE), | |
32 read_pending_(false), | |
33 write_pending_(false) { | |
34 transport_socket_ = new TransportSocket(socket, this); | |
35 } | |
36 | |
37 SSLSocketAdapter::~SSLSocketAdapter() { | |
38 } | |
39 | |
40 int SSLSocketAdapter::StartSSL(const char* hostname, bool restartable) { | |
41 DCHECK(!restartable); | |
42 hostname_ = hostname; | |
43 | |
44 if (socket_->GetState() != Socket::CS_CONNECTED) { | |
45 ssl_state_ = SSLSTATE_WAIT; | |
46 return 0; | |
47 } else { | |
48 return BeginSSL(); | |
49 } | |
50 } | |
51 | |
52 int SSLSocketAdapter::BeginSSL() { | |
53 if (!MessageLoop::current()) { | |
54 // Certificate verification is done via the Chrome message loop. | |
55 // Without this check, if we don't have a chrome message loop the | |
56 // SSL connection just hangs silently. | |
57 LOG(DFATAL) << "Chrome message loop (needed by SSL certificate " | |
58 << "verification) does not exist"; | |
59 return net::ERR_UNEXPECTED; | |
60 } | |
61 | |
62 // SSLConfigService is not thread-safe, and the default values for SSLConfig | |
63 // are correct for us, so we don't use the config service to initialize this | |
64 // object. | |
65 net::SSLConfig ssl_config; | |
66 net::SSLClientSocketContext context( | |
67 cert_verifier_.get(), NULL, transport_security_state_.get(), ""); | |
68 | |
69 transport_socket_->set_addr(talk_base::SocketAddress(hostname_, 0)); | |
70 ssl_socket_.reset( | |
71 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( | |
72 transport_socket_, net::HostPortPair(hostname_, 443), ssl_config, | |
73 context)); | |
74 | |
75 int result = ssl_socket_->Connect( | |
76 base::Bind(&SSLSocketAdapter::OnConnected, base::Unretained(this))); | |
77 | |
78 if (result == net::ERR_IO_PENDING || result == net::OK) { | |
79 return 0; | |
80 } else { | |
81 LOG(ERROR) << "Could not start SSL: " << net::ErrorToString(result); | |
82 return result; | |
83 } | |
84 } | |
85 | |
86 int SSLSocketAdapter::Send(const void* buf, size_t len) { | |
87 if (ssl_state_ == SSLSTATE_ERROR) { | |
88 SetError(EINVAL); | |
89 return -1; | |
90 } | |
91 | |
92 if (ssl_state_ == SSLSTATE_NONE) { | |
93 // Propagate the call to underlying socket if SSL is not connected | |
94 // yet (connection is not encrypted until StartSSL() is called). | |
95 return AsyncSocketAdapter::Send(buf, len); | |
96 } | |
97 | |
98 if (write_pending_) { | |
99 SetError(EWOULDBLOCK); | |
100 return -1; | |
101 } | |
102 | |
103 write_buffer_ = new net::DrainableIOBuffer(new net::IOBuffer(len), len); | |
104 memcpy(write_buffer_->data(), buf, len); | |
105 | |
106 DoWrite(); | |
107 | |
108 return len; | |
109 } | |
110 | |
111 int SSLSocketAdapter::Recv(void* buf, size_t len) { | |
112 switch (ssl_state_) { | |
113 case SSLSTATE_NONE: { | |
114 return AsyncSocketAdapter::Recv(buf, len); | |
115 } | |
116 | |
117 case SSLSTATE_WAIT: { | |
118 SetError(EWOULDBLOCK); | |
119 return -1; | |
120 } | |
121 | |
122 case SSLSTATE_CONNECTED: { | |
123 if (read_pending_) { | |
124 SetError(EWOULDBLOCK); | |
125 return -1; | |
126 } | |
127 | |
128 int bytes_read = 0; | |
129 | |
130 // Process any data we have left from the previous read. | |
131 if (read_buffer_) { | |
132 int size = std::min(read_buffer_->RemainingCapacity(), | |
133 static_cast<int>(len)); | |
134 memcpy(buf, read_buffer_->data(), size); | |
135 read_buffer_->set_offset(read_buffer_->offset() + size); | |
136 if (!read_buffer_->RemainingCapacity()) | |
137 read_buffer_ = NULL; | |
138 | |
139 if (size == static_cast<int>(len)) | |
140 return size; | |
141 | |
142 // If we didn't fill the caller's buffer then dispatch a new | |
143 // Read() in case there's more data ready. | |
144 buf = reinterpret_cast<char*>(buf) + size; | |
145 len -= size; | |
146 bytes_read = size; | |
147 DCHECK(!read_buffer_); | |
148 } | |
149 | |
150 // Dispatch a Read() request to the SSL layer. | |
151 read_buffer_ = new net::GrowableIOBuffer(); | |
152 read_buffer_->SetCapacity(len); | |
153 int result = ssl_socket_->Read( | |
154 read_buffer_, len, | |
155 base::Bind(&SSLSocketAdapter::OnRead, base::Unretained(this))); | |
156 if (result >= 0) | |
157 memcpy(buf, read_buffer_->data(), len); | |
158 | |
159 if (result == net::ERR_IO_PENDING) { | |
160 read_pending_ = true; | |
161 if (bytes_read) { | |
162 return bytes_read; | |
163 } else { | |
164 SetError(EWOULDBLOCK); | |
165 return -1; | |
166 } | |
167 } | |
168 | |
169 if (result < 0) { | |
170 SetError(EINVAL); | |
171 ssl_state_ = SSLSTATE_ERROR; | |
172 LOG(ERROR) << "Error reading from SSL socket " << result; | |
173 return -1; | |
174 } | |
175 read_buffer_ = NULL; | |
176 return result + bytes_read; | |
177 } | |
178 | |
179 case SSLSTATE_ERROR: { | |
180 SetError(EINVAL); | |
181 return -1; | |
182 } | |
183 } | |
184 | |
185 NOTREACHED(); | |
186 return -1; | |
187 } | |
188 | |
189 void SSLSocketAdapter::OnConnected(int result) { | |
190 if (result == net::OK) { | |
191 ssl_state_ = SSLSTATE_CONNECTED; | |
192 OnConnectEvent(this); | |
193 } else { | |
194 LOG(WARNING) << "OnConnected failed with error " << result; | |
195 } | |
196 } | |
197 | |
198 void SSLSocketAdapter::OnRead(int result) { | |
199 DCHECK(read_pending_); | |
200 read_pending_ = false; | |
201 if (result > 0) { | |
202 DCHECK_GE(read_buffer_->capacity(), result); | |
203 read_buffer_->SetCapacity(result); | |
204 } else { | |
205 if (result < 0) | |
206 ssl_state_ = SSLSTATE_ERROR; | |
207 } | |
208 AsyncSocketAdapter::OnReadEvent(this); | |
209 } | |
210 | |
211 void SSLSocketAdapter::OnWritten(int result) { | |
212 DCHECK(write_pending_); | |
213 write_pending_ = false; | |
214 if (result >= 0) { | |
215 write_buffer_->DidConsume(result); | |
216 if (!write_buffer_->BytesRemaining()) { | |
217 write_buffer_ = NULL; | |
218 } else { | |
219 DoWrite(); | |
220 } | |
221 } else { | |
222 ssl_state_ = SSLSTATE_ERROR; | |
223 } | |
224 AsyncSocketAdapter::OnWriteEvent(this); | |
225 } | |
226 | |
227 void SSLSocketAdapter::DoWrite() { | |
228 DCHECK_GT(write_buffer_->BytesRemaining(), 0); | |
229 DCHECK(!write_pending_); | |
230 | |
231 while (true) { | |
232 int result = ssl_socket_->Write( | |
233 write_buffer_, write_buffer_->BytesRemaining(), | |
234 base::Bind(&SSLSocketAdapter::OnWritten, base::Unretained(this))); | |
235 | |
236 if (result > 0) { | |
237 write_buffer_->DidConsume(result); | |
238 if (!write_buffer_->BytesRemaining()) { | |
239 write_buffer_ = NULL; | |
240 return; | |
241 } | |
242 continue; | |
243 } | |
244 | |
245 if (result == net::ERR_IO_PENDING) { | |
246 write_pending_ = true; | |
247 } else { | |
248 SetError(EINVAL); | |
249 ssl_state_ = SSLSTATE_ERROR; | |
250 } | |
251 return; | |
252 } | |
253 } | |
254 | |
255 void SSLSocketAdapter::OnConnectEvent(talk_base::AsyncSocket* socket) { | |
256 if (ssl_state_ != SSLSTATE_WAIT) { | |
257 AsyncSocketAdapter::OnConnectEvent(socket); | |
258 } else { | |
259 ssl_state_ = SSLSTATE_NONE; | |
260 int result = BeginSSL(); | |
261 if (0 != result) { | |
262 // TODO(zork): Handle this case gracefully. | |
263 LOG(WARNING) << "BeginSSL() failed with " << result; | |
264 } | |
265 } | |
266 } | |
267 | |
268 TransportSocket::TransportSocket(talk_base::AsyncSocket* socket, | |
269 SSLSocketAdapter *ssl_adapter) | |
270 : read_buffer_len_(0), | |
271 write_buffer_len_(0), | |
272 socket_(socket), | |
273 was_used_to_convey_data_(false) { | |
274 socket_->SignalReadEvent.connect(this, &TransportSocket::OnReadEvent); | |
275 socket_->SignalWriteEvent.connect(this, &TransportSocket::OnWriteEvent); | |
276 } | |
277 | |
278 TransportSocket::~TransportSocket() { | |
279 } | |
280 | |
281 int TransportSocket::Connect(const net::CompletionCallback& callback) { | |
282 // Connect is never called by SSLClientSocket, instead SSLSocketAdapter | |
283 // calls Connect() on socket_ directly. | |
284 NOTREACHED(); | |
285 return false; | |
286 } | |
287 | |
288 void TransportSocket::Disconnect() { | |
289 socket_->Close(); | |
290 } | |
291 | |
292 bool TransportSocket::IsConnected() const { | |
293 return (socket_->GetState() == talk_base::Socket::CS_CONNECTED); | |
294 } | |
295 | |
296 bool TransportSocket::IsConnectedAndIdle() const { | |
297 // Not implemented. | |
298 NOTREACHED(); | |
299 return false; | |
300 } | |
301 | |
302 int TransportSocket::GetPeerAddress(net::IPEndPoint* address) const { | |
303 talk_base::SocketAddress socket_address = socket_->GetRemoteAddress(); | |
304 if (jingle_glue::SocketAddressToIPEndPoint(socket_address, address)) { | |
305 return net::OK; | |
306 } else { | |
307 return net::ERR_FAILED; | |
308 } | |
309 } | |
310 | |
311 int TransportSocket::GetLocalAddress(net::IPEndPoint* address) const { | |
312 talk_base::SocketAddress socket_address = socket_->GetLocalAddress(); | |
313 if (jingle_glue::SocketAddressToIPEndPoint(socket_address, address)) { | |
314 return net::OK; | |
315 } else { | |
316 return net::ERR_FAILED; | |
317 } | |
318 } | |
319 | |
320 const net::BoundNetLog& TransportSocket::NetLog() const { | |
321 return net_log_; | |
322 } | |
323 | |
324 void TransportSocket::SetSubresourceSpeculation() { | |
325 NOTREACHED(); | |
326 } | |
327 | |
328 void TransportSocket::SetOmniboxSpeculation() { | |
329 NOTREACHED(); | |
330 } | |
331 | |
332 bool TransportSocket::WasEverUsed() const { | |
333 // We don't use this in ClientSocketPools, so this should never be used. | |
334 NOTREACHED(); | |
335 return was_used_to_convey_data_; | |
336 } | |
337 | |
338 bool TransportSocket::UsingTCPFastOpen() const { | |
339 return false; | |
340 } | |
341 | |
342 int64 TransportSocket::NumBytesRead() const { | |
343 NOTREACHED(); | |
344 return -1; | |
345 } | |
346 | |
347 base::TimeDelta TransportSocket::GetConnectTimeMicros() const { | |
348 NOTREACHED(); | |
349 return base::TimeDelta::FromMicroseconds(-1); | |
350 } | |
351 | |
352 bool TransportSocket::WasNpnNegotiated() const { | |
353 NOTREACHED(); | |
354 return false; | |
355 } | |
356 | |
357 net::NextProto TransportSocket::GetNegotiatedProtocol() const { | |
358 NOTREACHED(); | |
359 return net::kProtoUnknown; | |
360 } | |
361 | |
362 bool TransportSocket::GetSSLInfo(net::SSLInfo* ssl_info) { | |
363 NOTREACHED(); | |
364 return false; | |
365 } | |
366 | |
367 int TransportSocket::Read(net::IOBuffer* buf, int buf_len, | |
368 const net::CompletionCallback& callback) { | |
369 DCHECK(buf); | |
370 DCHECK(read_callback_.is_null()); | |
371 DCHECK(!read_buffer_.get()); | |
372 int result = socket_->Recv(buf->data(), buf_len); | |
373 if (result < 0) { | |
374 result = net::MapSystemError(socket_->GetError()); | |
375 if (result == net::ERR_IO_PENDING) { | |
376 read_callback_ = callback; | |
377 read_buffer_ = buf; | |
378 read_buffer_len_ = buf_len; | |
379 } | |
380 } | |
381 if (result != net::ERR_IO_PENDING) | |
382 was_used_to_convey_data_ = true; | |
383 return result; | |
384 } | |
385 | |
386 int TransportSocket::Write(net::IOBuffer* buf, int buf_len, | |
387 const net::CompletionCallback& callback) { | |
388 DCHECK(buf); | |
389 DCHECK(write_callback_.is_null()); | |
390 DCHECK(!write_buffer_.get()); | |
391 int result = socket_->Send(buf->data(), buf_len); | |
392 if (result < 0) { | |
393 result = net::MapSystemError(socket_->GetError()); | |
394 if (result == net::ERR_IO_PENDING) { | |
395 write_callback_ = callback; | |
396 write_buffer_ = buf; | |
397 write_buffer_len_ = buf_len; | |
398 } | |
399 } | |
400 if (result != net::ERR_IO_PENDING) | |
401 was_used_to_convey_data_ = true; | |
402 return result; | |
403 } | |
404 | |
405 bool TransportSocket::SetReceiveBufferSize(int32 size) { | |
406 // Not implemented. | |
407 return false; | |
408 } | |
409 | |
410 bool TransportSocket::SetSendBufferSize(int32 size) { | |
411 // Not implemented. | |
412 return false; | |
413 } | |
414 | |
415 void TransportSocket::OnReadEvent(talk_base::AsyncSocket* socket) { | |
416 if (!read_callback_.is_null()) { | |
417 DCHECK(read_buffer_.get()); | |
418 net::CompletionCallback callback = read_callback_; | |
419 scoped_refptr<net::IOBuffer> buffer = read_buffer_; | |
420 int buffer_len = read_buffer_len_; | |
421 | |
422 read_callback_.Reset(); | |
423 read_buffer_ = NULL; | |
424 read_buffer_len_ = 0; | |
425 | |
426 int result = socket_->Recv(buffer->data(), buffer_len); | |
427 if (result < 0) { | |
428 result = net::MapSystemError(socket_->GetError()); | |
429 if (result == net::ERR_IO_PENDING) { | |
430 read_callback_ = callback; | |
431 read_buffer_ = buffer; | |
432 read_buffer_len_ = buffer_len; | |
433 return; | |
434 } | |
435 } | |
436 was_used_to_convey_data_ = true; | |
437 callback.Run(result); | |
438 } | |
439 } | |
440 | |
441 void TransportSocket::OnWriteEvent(talk_base::AsyncSocket* socket) { | |
442 if (!write_callback_.is_null()) { | |
443 DCHECK(write_buffer_.get()); | |
444 net::CompletionCallback callback = write_callback_; | |
445 scoped_refptr<net::IOBuffer> buffer = write_buffer_; | |
446 int buffer_len = write_buffer_len_; | |
447 | |
448 write_callback_.Reset(); | |
449 write_buffer_ = NULL; | |
450 write_buffer_len_ = 0; | |
451 | |
452 int result = socket_->Send(buffer->data(), buffer_len); | |
453 if (result < 0) { | |
454 result = net::MapSystemError(socket_->GetError()); | |
455 if (result == net::ERR_IO_PENDING) { | |
456 write_callback_ = callback; | |
457 write_buffer_ = buffer; | |
458 write_buffer_len_ = buffer_len; | |
459 return; | |
460 } | |
461 } | |
462 was_used_to_convey_data_ = true; | |
463 callback.Run(result); | |
464 } | |
465 } | |
466 | |
467 } // namespace remoting | |
OLD | NEW |