Index: net/socket/ssl_client_socket_openssl.cc |
diff --git a/net/socket/ssl_client_socket_openssl.cc b/net/socket/ssl_client_socket_openssl.cc |
index f933031539b007fb3dd02cd4616538b8f065d050..547de444b415611d1ebc64481263efcda89ed0ce 100644 |
--- a/net/socket/ssl_client_socket_openssl.cc |
+++ b/net/socket/ssl_client_socket_openssl.cc |
@@ -200,28 +200,32 @@ int NoOpVerifyCallback(X509_STORE_CTX*, void *) { |
// OpenSSL manages a cache of SSL_SESSION, this class provides the application |
// side policy for that cache about session re-use: we retain one session per |
-// unique HostPortPair. |
+// unique HostPortPair, per shard. |
class SSLSessionCache { |
public: |
SSLSessionCache() {} |
- void OnSessionAdded(const HostPortPair& host_and_port, SSL_SESSION* session) { |
+ void OnSessionAdded(const HostPortPair& host_and_port, |
+ const std::string& shard, |
+ SSL_SESSION* session) { |
// Declare the session cleaner-upper before the lock, so any call into |
// OpenSSL to free the session will happen after the lock is released. |
crypto::ScopedOpenSSL<SSL_SESSION, SSL_SESSION_free> session_to_free; |
base::AutoLock lock(lock_); |
DCHECK_EQ(0U, session_map_.count(session)); |
+ const std::string cache_key = GetCacheKey(host_and_port, shard); |
+ |
std::pair<HostPortMap::iterator, bool> res = |
- host_port_map_.insert(std::make_pair(host_and_port, session)); |
+ host_port_map_.insert(std::make_pair(cache_key, session)); |
if (!res.second) { // Already exists: replace old entry. |
session_to_free.reset(res.first->second); |
session_map_.erase(session_to_free.get()); |
res.first->second = session; |
} |
DVLOG(2) << "Adding session " << session << " => " |
- << host_and_port.ToString() << ", new entry = " << res.second; |
- DCHECK(host_port_map_[host_and_port] == session); |
+ << cache_key << ", new entry = " << res.second; |
+ DCHECK(host_port_map_[cache_key] == session); |
session_map_[session] = res.first; |
DCHECK_EQ(host_port_map_.size(), session_map_.size()); |
DCHECK_LE(host_port_map_.size(), kSessionCacheMaxEntires); |
@@ -236,8 +240,7 @@ class SSLSessionCache { |
SessionMap::iterator it = session_map_.find(session); |
if (it == session_map_.end()) |
return; |
- DVLOG(2) << "Remove session " << session << " => " |
- << it->second->first.ToString(); |
+ DVLOG(2) << "Remove session " << session << " => " << it->second->first; |
DCHECK(it->second->second == session); |
host_port_map_.erase(it->second); |
session_map_.erase(it); |
@@ -247,13 +250,14 @@ class SSLSessionCache { |
// Looks up the host:port in the cache, and if a session is found it is added |
// to |ssl|, returning true on success. |
- bool SetSSLSession(SSL* ssl, const HostPortPair& host_and_port) { |
+ bool SetSSLSession(SSL* ssl, const HostPortPair& host_and_port, |
+ const std::string& shard) { |
base::AutoLock lock(lock_); |
- HostPortMap::iterator it = host_port_map_.find(host_and_port); |
+ const std::string cache_key = GetCacheKey(host_and_port, shard); |
+ HostPortMap::iterator it = host_port_map_.find(cache_key); |
if (it == host_port_map_.end()) |
return false; |
- DVLOG(2) << "Lookup session: " << it->second << " => " |
- << host_and_port.ToString(); |
+ DVLOG(2) << "Lookup session: " << it->second << " => " << cache_key; |
SSL_SESSION* session = it->second; |
DCHECK(session); |
DCHECK(session_map_[session] == it); |
@@ -265,12 +269,26 @@ class SSLSessionCache { |
return SSL_set_session(ssl, session) == 1; |
} |
+ // Flush removes all entries from the cache. This is called when a client |
+ // certificate is added. |
+ void Flush() { |
+ for (HostPortMap::iterator i = host_port_map_.begin(); |
+ i != host_port_map_.end(); i++) { |
+ SSL_SESSION_free(i->second); |
+ } |
+ host_port_map_.clear(); |
+ session_map_.clear(); |
+ } |
+ |
private: |
+ static std::string GetCacheKey(const HostPortPair& host_and_port, |
+ const std::string& shard) { |
+ return host_and_port.ToString() + "/" + shard; |
+ } |
+ |
// A pair of maps to allow bi-directional lookups between host:port and an |
// associated session. |
- // TODO(joth): When client certificates are implemented we should key the |
- // cache on the client certificate used in addition to the host-port pair. |
- typedef std::map<HostPortPair, SSL_SESSION*> HostPortMap; |
+ typedef std::map<std::string, SSL_SESSION*> HostPortMap; |
typedef std::map<SSL_SESSION*, HostPortMap::iterator> SessionMap; |
HostPortMap host_port_map_; |
SessionMap session_map_; |
@@ -329,7 +347,9 @@ class SSLContext { |
int NewSessionCallback(SSL* ssl, SSL_SESSION* session) { |
SSLClientSocketOpenSSL* socket = GetClientSocketFromSSL(ssl); |
- session_cache_.OnSessionAdded(socket->host_and_port(), session); |
+ session_cache_.OnSessionAdded(socket->host_and_port(), |
+ socket->ssl_session_cache_shard(), |
+ session); |
return 1; // 1 => We took ownership of |session|. |
} |
@@ -360,8 +380,11 @@ class SSLContext { |
// SSLClientSocketOpenSSL object from an SSL instance. |
int ssl_socket_data_index_; |
- crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_; |
+ // session_cache_ must appear before |ssl_ctx_| because the destruction of |
+ // |ssl_ctx_| may trigger callbacks into |session_cache_|. Therefore, |
+ // |session_cache_| must be destructed after |ssl_ctx_|. |
SSLSessionCache session_cache_; |
+ crypto::ScopedOpenSSL<SSL_CTX, SSL_CTX_free> ssl_ctx_; |
}; |
// Utility to construct the appropriate set & clear masks for use the OpenSSL |
@@ -379,6 +402,12 @@ struct SslSetClearMask { |
} // namespace |
+// static |
+void SSLClientSocket::ClearSessionCache() { |
+ SSLContext* context = SSLContext::GetInstance(); |
+ context->session_cache()->Flush(); |
+} |
+ |
SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( |
ClientSocketHandle* transport_socket, |
const HostPortPair& host_and_port, |
@@ -394,6 +423,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( |
transport_(transport_socket), |
host_and_port_(host_and_port), |
ssl_config_(ssl_config), |
+ ssl_session_cache_shard_(context.ssl_session_cache_shard), |
trying_cached_session_(false), |
npn_status_(kNextProtoUnsupported), |
net_log_(transport_socket->socket()->NetLog()) { |
@@ -418,7 +448,8 @@ bool SSLClientSocketOpenSSL::Init() { |
return false; |
trying_cached_session_ = |
- context->session_cache()->SetSSLSession(ssl_, host_and_port_); |
+ context->session_cache()->SetSSLSession(ssl_, host_and_port_, |
+ ssl_session_cache_shard_); |
BIO* ssl_bio = NULL; |
// 0 => use default buffer sizes. |
@@ -651,6 +682,9 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { |
void SSLClientSocketOpenSSL::Disconnect() { |
if (ssl_) { |
+ // Calling SSL_shutdown prevents the session from being marked as |
+ // unresumable. |
+ SSL_shutdown(ssl_); |
SSL_free(ssl_); |
ssl_ = NULL; |
} |