Index: net/ssl/client_cert_store_win.cc |
diff --git a/net/ssl/client_cert_store_win.cc b/net/ssl/client_cert_store_win.cc |
index fec9c33ae67b7d108e0468e4d686e606822213a1..b82a94b826f248dbdedb8773b0ab818543ee1968 100644 |
--- a/net/ssl/client_cert_store_win.cc |
+++ b/net/ssl/client_cert_store_win.cc |
@@ -11,15 +11,55 @@ |
#include <windows.h> |
#include <security.h> |
+#include "base/bind.h" |
+#include "base/bind_helpers.h" |
#include "base/callback.h" |
#include "base/logging.h" |
+#include "base/memory/ptr_util.h" |
+#include "base/task_runner_util.h" |
+#include "base/threading/thread_task_runner_handle.h" |
#include "crypto/wincrypt_shim.h" |
#include "net/cert/x509_util.h" |
+#include "net/ssl/ssl_platform_key_win.h" |
+#include "net/ssl/ssl_private_key.h" |
namespace net { |
namespace { |
+class ClientCertIdentityWin : public ClientCertIdentity { |
+ public: |
+ // Takes ownership of |cert_context|. |
+ ClientCertIdentityWin( |
+ scoped_refptr<net::X509Certificate> cert, |
+ PCCERT_CONTEXT cert_context, |
+ scoped_refptr<base::SingleThreadTaskRunner> key_task_runner) |
+ : ClientCertIdentity(std::move(cert)), |
+ cert_context_(cert_context), |
+ key_task_runner_(key_task_runner) {} |
+ ~ClientCertIdentityWin() override { |
+ CertFreeCertificateContext(cert_context_); |
+ } |
+ |
+ void AcquirePrivateKey( |
+ const base::Callback<void(scoped_refptr<SSLPrivateKey>)>& |
+ private_key_callback) override { |
+ if (base::PostTaskAndReplyWithResult( |
+ key_task_runner_.get(), FROM_HERE, |
+ base::Bind(&FetchClientCertPrivateKey, |
+ base::Unretained(certificate()), cert_context_), |
+ private_key_callback)) { |
+ return; |
+ } |
+ // If the task could not be posted, behave as if there was no key. |
+ private_key_callback.Run(nullptr); |
+ } |
+ |
+ private: |
+ PCCERT_CONTEXT cert_context_; |
+ scoped_refptr<base::SingleThreadTaskRunner> key_task_runner_; |
+}; |
+ |
// Callback required by Windows API function CertFindChainInStore(). In addition |
// to filtering by extended/enhanced key usage, we do not show expired |
// certificates and require digital signature usage in the key usage extension. |
@@ -65,8 +105,11 @@ static BOOL WINAPI ClientCertFindCallback(PCCERT_CONTEXT cert_context, |
void GetClientCertsImpl(HCERTSTORE cert_store, |
const SSLCertRequestInfo& request, |
- CertificateList* selected_certs) { |
- selected_certs->clear(); |
+ ClientCertIdentityList* selected_identities) { |
+ selected_identities->clear(); |
+ |
+ scoped_refptr<base::SingleThreadTaskRunner> current_thread = |
+ base::ThreadTaskRunnerHandle::Get(); |
const size_t auth_count = request.cert_authorities.size(); |
std::vector<CERT_NAME_BLOB> issuers(auth_count); |
@@ -149,15 +192,19 @@ void GetClientCertsImpl(HCERTSTORE cert_store, |
// pair<X509Certificate, SSLPrivateKeyCallback>. |
scoped_refptr<X509Certificate> cert = X509Certificate::CreateFromHandle( |
cert_context2, intermediates); |
- if (cert) |
- selected_certs->push_back(std::move(cert)); |
- CertFreeCertificateContext(cert_context2); |
+ if (cert) { |
+ selected_identities->push_back(base::MakeUnique<ClientCertIdentityWin>( |
+ std::move(cert), |
+ cert_context2, // Takes ownership of |cert_context2|. |
+ current_thread)); // The key must be acquired on the same thread, as |
+ // the PCCERT_CONTEXT may not be thread safe. |
+ } |
for (size_t i = 0; i < intermediates.size(); ++i) |
CertFreeCertificateContext(intermediates[i]); |
} |
- std::sort(selected_certs->begin(), selected_certs->end(), |
- x509_util::ClientCertSorter()); |
+ std::sort(selected_identities->begin(), selected_identities->end(), |
+ ClientCertIdentitySorter()); |
} |
} // namespace |
@@ -174,13 +221,13 @@ ClientCertStoreWin::~ClientCertStoreWin() {} |
void ClientCertStoreWin::GetClientCerts( |
const SSLCertRequestInfo& request, |
const ClientCertListCallback& callback) { |
- CertificateList selected_certs; |
+ ClientCertIdentityList selected_identities; |
if (cert_store_) { |
// Use the existing client cert store. Note: Under some situations, |
// it's possible for this to return certificates that aren't usable |
// (see below). |
- GetClientCertsImpl(cert_store_, request, &selected_certs); |
- callback.Run(std::move(selected_certs)); |
+ GetClientCertsImpl(cert_store_, request, &selected_identities); |
+ callback.Run(std::move(selected_identities)); |
return; |
} |
@@ -191,18 +238,18 @@ void ClientCertStoreWin::GetClientCerts( |
ScopedHCERTSTORE my_cert_store(CertOpenSystemStore(NULL, L"MY")); |
if (!my_cert_store) { |
PLOG(ERROR) << "Could not open the \"MY\" system certificate store: "; |
- callback.Run(CertificateList()); |
+ callback.Run(ClientCertIdentityList()); |
return; |
} |
- GetClientCertsImpl(my_cert_store, request, &selected_certs); |
- callback.Run(std::move(selected_certs)); |
+ GetClientCertsImpl(my_cert_store, request, &selected_identities); |
+ callback.Run(std::move(selected_identities)); |
} |
bool ClientCertStoreWin::SelectClientCertsForTesting( |
const CertificateList& input_certs, |
const SSLCertRequestInfo& request, |
- CertificateList* selected_certs) { |
+ ClientCertIdentityList* selected_identities) { |
ScopedHCERTSTORE test_store(CertOpenStore(CERT_STORE_PROV_MEMORY, 0, NULL, 0, |
NULL)); |
if (!test_store) |
@@ -232,7 +279,7 @@ bool ClientCertStoreWin::SelectClientCertsForTesting( |
return false; |
} |
- GetClientCertsImpl(test_store.get(), request, selected_certs); |
+ GetClientCertsImpl(test_store.get(), request, selected_identities); |
return true; |
} |