Index: net/socket/ssl_client_socket_openssl.cc |
diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc |
index 4ff8d438e965b1fdf0b5b7f8b4ecdb75e1d14d66..d6899f8a5e1e64b4c5582a659942028217a4b06c 100644 |
--- a/net/socket/ssl_client_socket_openssl.cc |
+++ b/net/socket/ssl_client_socket_openssl.cc |
@@ -23,6 +23,7 @@ |
#include "net/cert/single_request_cert_verifier.h" |
#include "net/cert/x509_certificate_net_log_param.h" |
#include "net/socket/openssl_ssl_util.h" |
+#include "net/socket/ssl_client_socket_pool.h" |
wtc
2014/07/15 19:27:59
We should be able to remove this now.
mshelley
2014/07/17 00:28:45
Done.
|
#include "net/socket/ssl_error_params.h" |
#include "net/socket/ssl_session_cache_openssl.h" |
#include "net/ssl/openssl_client_key_store.h" |
@@ -87,14 +88,6 @@ int GetNetSSLVersion(SSL* ssl) { |
} |
} |
-// Compute a unique key string for the SSL session cache. |socket| is an |
-// input socket object. Return a string. |
-std::string GetSocketSessionCacheKey(const SSLClientSocketOpenSSL& socket) { |
- std::string result = socket.host_and_port().ToString(); |
- result.append("/"); |
- result.append(socket.ssl_session_cache_shard()); |
- return result; |
-} |
} // namespace |
@@ -139,7 +132,7 @@ class SSLClientSocketOpenSSL::SSLContext { |
static std::string GetSessionCacheKey(const SSL* ssl) { |
SSLClientSocketOpenSSL* socket = GetInstance()->GetClientSocketFromSSL(ssl); |
DCHECK(socket); |
- return GetSocketSessionCacheKey(*socket); |
+ return socket->GetSessionCacheKey(); |
} |
static SSLSessionCacheOpenSSL::Config kDefaultSessionCacheConfig; |
@@ -360,12 +353,49 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( |
npn_status_(kNextProtoUnsupported), |
channel_id_request_return_value_(ERR_UNEXPECTED), |
channel_id_xtn_negotiated_(false), |
- net_log_(transport_->socket()->NetLog()) {} |
+ net_log_(transport_->socket()->NetLog()) { |
+} |
SSLClientSocketOpenSSL::~SSLClientSocketOpenSSL() { |
Disconnect(); |
} |
+std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const { |
+ return FormatSessionCacheKey(host_and_port_.ToString(), |
+ ssl_session_cache_shard_); |
+} |
+ |
+bool SSLClientSocketOpenSSL::InSessionCache() const { |
+ SSLContext* context = SSLContext::GetInstance(); |
+ std::string cache_key = GetSessionCacheKey(); |
+ return context->session_cache()->SSLSessionIsInCache(cache_key); |
+} |
+ |
+void SSLClientSocketOpenSSL::SetHandshakeSuccessCallback( |
+ const base::Closure& callback) { |
+ success_callback_ = callback; |
+ SSLContext* context = SSLContext::GetInstance(); |
+ context->session_cache()->SetSessionAddedCallback( |
+ ssl_, |
+ base::Bind(&SSLClientSocketOpenSSL::OnHandshakeSuccess, |
+ base::Unretained(this))); |
+} |
+ |
+void SSLClientSocketOpenSSL::SetHandshakeFailureCallback( |
+ const base::Closure& callback) { |
+ error_callback_ = callback; |
+} |
+ |
+void SSLClientSocketOpenSSL::OnHandshakeSuccess() { |
+ error_callback_.Reset(); |
+ base::ResetAndReturn(&success_callback_).Run(); |
+} |
+ |
+void SSLClientSocketOpenSSL::OnHandshakeFailure() { |
+ if (!error_callback_.is_null()) |
+ base::ResetAndReturn(&error_callback_).Run(); |
+} |
+ |
void SSLClientSocketOpenSSL::GetSSLCertRequestInfo( |
SSLCertRequestInfo* cert_request_info) { |
cert_request_info->host_and_port = host_and_port_; |
@@ -436,6 +466,9 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { |
} |
void SSLClientSocketOpenSSL::Disconnect() { |
+ OnHandshakeFailure(); |
+ SSLContext* context = SSLContext::GetInstance(); |
+ context->session_cache()->RemoveSessionAddedCallback(ssl_); |
wtc
2014/07/15 19:27:58
Move these two lines into the if (ssl_) statement
mshelley
2014/07/17 00:28:45
Done.
|
if (ssl_) { |
// Calling SSL_shutdown prevents the session from being marked as |
// unresumable. |
@@ -658,7 +691,7 @@ int SSLClientSocketOpenSSL::Init() { |
return ERR_UNEXPECTED; |
trying_cached_session_ = context->session_cache()->SetSSLSessionWithKey( |
- ssl_, GetSocketSessionCacheKey(*this)); |
+ ssl_, GetSessionCacheKey()); |
BIO* ssl_bio = NULL; |
// 0 => use default buffer sizes. |
@@ -795,6 +828,9 @@ int SSLClientSocketOpenSSL::DoHandshake() { |
int net_error = OK; |
int rv = SSL_do_handshake(ssl_); |
+ if (rv <= 0) |
+ OnHandshakeFailure(); |
wtc
2014/07/15 19:27:59
Delete this. I think it's better to do this after
mshelley
2014/07/17 00:28:45
Done.
|
+ |
if (client_auth_cert_needed_) { |
net_error = ERR_SSL_CLIENT_AUTH_CERT_NEEDED; |
// If the handshake already succeeded (because the server requests but |
@@ -1064,6 +1100,10 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { |
do { |
rv = SSL_read(ssl_, user_read_buf_->data() + total_bytes_read, |
user_read_buf_len_ - total_bytes_read); |
+ // Failure of the first read attempt indicates a failed false start |
+ // connection. |
wtc
2014/07/15 19:27:59
Update this comment.
mshelley
2014/07/17 00:28:45
Done.
|
+ if (rv <= OK) |
wtc
2014/07/15 19:27:59
This should be
if (rv <= 0)
You can combine t
mshelley
2014/07/17 00:28:45
Done.
|
+ OnHandshakeFailure(); |
if (rv > 0) |
total_bytes_read += rv; |
} while (total_bytes_read < user_read_buf_len_ && rv > 0); |
@@ -1116,7 +1156,11 @@ int SSLClientSocketOpenSSL::DoPayloadRead() { |
int SSLClientSocketOpenSSL::DoPayloadWrite() { |
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); |
int rv = SSL_write(ssl_, user_write_buf_->data(), user_write_buf_len_); |
- |
+ // Failure of the second write attempt indicates a failed false start |
+ // connection. |
+ if (rv <= 0) { |
wtc
2014/07/15 19:27:59
This should be
if (rv < 0)
You can just call On
mshelley
2014/07/17 00:28:45
Done.
|
+ OnHandshakeFailure(); |
+ } |
wtc
2014/07/15 19:27:59
Omit the curly braces.
mshelley
2014/07/17 00:28:45
Done.
|
if (rv >= 0) { |
net_log_.AddByteTransferEvent(NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, |
user_write_buf_->data()); |