Chromium Code Reviews| Index: net/base/origin_bound_cert_service.cc |
| diff --git a/net/base/origin_bound_cert_service.cc b/net/base/origin_bound_cert_service.cc |
| index 4d1af288fef99089d50925f5c3a857214c2a861f..820c0975614a74101f40f5e1ff1dbf444bc84692 100644 |
| --- a/net/base/origin_bound_cert_service.cc |
| +++ b/net/base/origin_bound_cert_service.cc |
| @@ -4,6 +4,7 @@ |
| #include "net/base/origin_bound_cert_service.h" |
| +#include <algorithm> |
| #include <limits> |
| #include "base/compiler_specific.h" |
| @@ -15,6 +16,7 @@ |
| #include "base/rand_util.h" |
| #include "base/stl_util.h" |
| #include "base/threading/worker_pool.h" |
| +#include "crypto/ec_private_key.h" |
| #include "crypto/rsa_private_key.h" |
| #include "net/base/net_errors.h" |
| #include "net/base/origin_bound_cert_store.h" |
| @@ -32,15 +34,27 @@ namespace { |
| const int kKeySizeInBits = 1024; |
| const int kValidityPeriodInDays = 365; |
| +bool IsSupportedCertType(uint8 type) { |
| + switch(type) { |
| + case CLIENT_CERT_RSA_SIGN: |
| + case CLIENT_CERT_ECDSA_SIGN: |
| + return true; |
| + default: |
| + return false; |
| + } |
| +} |
| + |
| } // namespace |
| // Represents the output and result callback of a request. |
| class OriginBoundCertServiceRequest { |
| public: |
| OriginBoundCertServiceRequest(const CompletionCallback& callback, |
| + SSLClientCertType* type, |
| std::string* private_key, |
| std::string* cert) |
| : callback_(callback), |
| + type_(type), |
| private_key_(private_key), |
| cert_(cert) { |
| } |
| @@ -48,6 +62,7 @@ class OriginBoundCertServiceRequest { |
| // Ensures that the result callback will never be made. |
| void Cancel() { |
| callback_.Reset(); |
| + type_ = NULL; |
| private_key_ = NULL; |
| cert_ = NULL; |
| } |
| @@ -55,9 +70,11 @@ class OriginBoundCertServiceRequest { |
| // Copies the contents of |private_key| and |cert| to the caller's output |
| // arguments and calls the callback. |
| void Post(int error, |
| + SSLClientCertType type, |
| const std::string& private_key, |
| const std::string& cert) { |
| if (!callback_.is_null()) { |
| + *type_ = type; |
| *private_key_ = private_key; |
| *cert_ = cert; |
| callback_.Run(error); |
| @@ -69,6 +86,7 @@ class OriginBoundCertServiceRequest { |
| private: |
| CompletionCallback callback_; |
| + SSLClientCertType* type_; |
| std::string* private_key_; |
| std::string* cert_; |
| }; |
| @@ -80,8 +98,10 @@ class OriginBoundCertServiceWorker { |
| public: |
| OriginBoundCertServiceWorker( |
| const std::string& origin, |
| + SSLClientCertType type, |
| OriginBoundCertService* origin_bound_cert_service) |
| : origin_(origin), |
| + type_(type), |
| serial_number_(base::RandInt(0, std::numeric_limits<int>::max())), |
| origin_loop_(MessageLoop::current()), |
| origin_bound_cert_service_(origin_bound_cert_service), |
| @@ -110,6 +130,7 @@ class OriginBoundCertServiceWorker { |
| void Run() { |
| // Runs on a worker thread. |
| error_ = OriginBoundCertService::GenerateCert(origin_, |
| + type_, |
| serial_number_, |
| &private_key_, |
| &cert_); |
| @@ -136,7 +157,7 @@ class OriginBoundCertServiceWorker { |
| // memory leaks or worse errors. |
| base::AutoLock locked(lock_); |
| if (!canceled_) { |
| - origin_bound_cert_service_->HandleResult(origin_, error_, |
| + origin_bound_cert_service_->HandleResult(origin_, error_, type_, |
| private_key_, cert_); |
| } |
| } |
| @@ -169,6 +190,7 @@ class OriginBoundCertServiceWorker { |
| } |
| const std::string origin_; |
| + const SSLClientCertType type_; |
| // Note that serial_number_ must be initialized on a non-worker thread |
| // (see documentation for OriginBoundCertService::GenerateCert). |
| uint32 serial_number_; |
| @@ -195,8 +217,9 @@ class OriginBoundCertServiceWorker { |
| // origin message loop. |
| class OriginBoundCertServiceJob { |
| public: |
| - explicit OriginBoundCertServiceJob(OriginBoundCertServiceWorker* worker) |
| - : worker_(worker) { |
| + OriginBoundCertServiceJob(OriginBoundCertServiceWorker* worker, |
| + SSLClientCertType type) |
| + : worker_(worker), type_(type) { |
| } |
| ~OriginBoundCertServiceJob() { |
| @@ -206,19 +229,23 @@ class OriginBoundCertServiceJob { |
| } |
| } |
| + SSLClientCertType type() const { return type_; } |
| + |
| void AddRequest(OriginBoundCertServiceRequest* request) { |
| requests_.push_back(request); |
| } |
| void HandleResult(int error, |
| + SSLClientCertType type, |
| const std::string& private_key, |
| const std::string& cert) { |
| worker_ = NULL; |
| - PostAll(error, private_key, cert); |
| + PostAll(error, type, private_key, cert); |
| } |
| private: |
| void PostAll(int error, |
| + SSLClientCertType type, |
| const std::string& private_key, |
| const std::string& cert) { |
| std::vector<OriginBoundCertServiceRequest*> requests; |
| @@ -226,7 +253,7 @@ class OriginBoundCertServiceJob { |
| for (std::vector<OriginBoundCertServiceRequest*>::iterator |
| i = requests.begin(); i != requests.end(); i++) { |
| - (*i)->Post(error, private_key, cert); |
| + (*i)->Post(error, type, private_key, cert); |
| // Post() causes the OriginBoundCertServiceRequest to delete itself. |
| } |
| } |
| @@ -244,8 +271,12 @@ class OriginBoundCertServiceJob { |
| std::vector<OriginBoundCertServiceRequest*> requests_; |
| OriginBoundCertServiceWorker* worker_; |
| + SSLClientCertType type_; |
| }; |
| +// static |
| +const char OriginBoundCertService::kEPKIPassword[] = ""; |
| + |
| OriginBoundCertService::OriginBoundCertService( |
| OriginBoundCertStore* origin_bound_cert_store) |
| : origin_bound_cert_store_(origin_bound_cert_store), |
| @@ -259,43 +290,80 @@ OriginBoundCertService::~OriginBoundCertService() { |
| int OriginBoundCertService::GetOriginBoundCert( |
| const std::string& origin, |
| + const std::vector<uint8>& requested_types, |
| + SSLClientCertType* type, |
| std::string* private_key, |
| std::string* cert, |
| const CompletionCallback& callback, |
| RequestHandle* out_req) { |
| DCHECK(CalledOnValidThread()); |
| - if (callback.is_null() || !private_key || !cert || origin.empty()) { |
| + if (callback.is_null() || !private_key || !cert || origin.empty() || |
| + requested_types.empty()) { |
| *out_req = NULL; |
| return ERR_INVALID_ARGUMENT; |
| } |
| + SSLClientCertType preferred_type = CLIENT_CERT_INVALID_TYPE; |
| + for (size_t i = 0; i < requested_types.size(); ++i) { |
| + if (IsSupportedCertType(requested_types[i])) { |
| + preferred_type = static_cast<SSLClientCertType>(requested_types[i]); |
| + break; |
| + } |
| + } |
| + if (preferred_type == CLIENT_CERT_INVALID_TYPE) { |
| + // None of the requested types are supported. |
| + *out_req = NULL; |
| + return ERR_ORIGIN_BOUND_CERT_TYPE_UNSUPPORTED; |
| + } |
| + |
| requests_++; |
| - // Check if an origin bound cert already exists for this origin. |
| + // Check if an origin bound cert of an acceptable type already exists for this |
| + // origin. |
| if (origin_bound_cert_store_->GetOriginBoundCert(origin, |
| + type, |
| private_key, |
| cert)) { |
| - cert_store_hits_++; |
| - *out_req = NULL; |
| - return OK; |
| + if (IsSupportedCertType(*type) && |
| + std::find(requested_types.begin(), requested_types.end(), *type) != |
| + requested_types.end()) { |
| + cert_store_hits_++; |
| + *out_req = NULL; |
| + return OK; |
| + } |
| + DVLOG(1) << "Cert store had cert of wrong type " << *type << " for " |
| + << origin; |
| } |
| // |origin_bound_cert_store_| has no cert for this origin. See if an |
| // identical request is currently in flight. |
| - OriginBoundCertServiceJob* job; |
| - std::map<std::string, OriginBoundCertServiceJob*>::const_iterator j; |
| + OriginBoundCertServiceJob* job = NULL; |
| + std::map<std::string, OriginBoundCertServiceJob*>::iterator j; |
|
wtc
2011/12/06 00:18:05
We can keep the initialization of 'job' to NULL. B
mattm
2011/12/06 00:54:01
Done.
|
| j = inflight_.find(origin); |
| if (j != inflight_.end()) { |
| // An identical request is in flight already. We'll just attach our |
| // callback. |
| - inflight_joins_++; |
| job = j->second; |
| + // Check that the job is for an acceptable type of cert. |
| + if (std::find(requested_types.begin(), requested_types.end(), job->type()) |
| + == requested_types.end()) { |
| + DVLOG(1) << "Found inflight job of wrong type " << job->type() |
| + << " for " << origin; |
| + *out_req = NULL; |
| + // If we get here, the server is asking for different types of certs in |
| + // short succession. This probably means the server is broken or |
| + // misconfigured. Since we only store one type of cert per origin, we |
| + // are unable to handle this well. Just return an error and let the first |
| + // job finish. |
| + return ERR_ORIGIN_BOUND_CERT_GENERATION_TYPE_MISMATCH; |
| + } |
| + inflight_joins_++; |
| } else { |
| // Need to make a new request. |
| OriginBoundCertServiceWorker* worker = |
| - new OriginBoundCertServiceWorker(origin, this); |
| - job = new OriginBoundCertServiceJob(worker); |
| + new OriginBoundCertServiceWorker(origin, preferred_type, this); |
| + job = new OriginBoundCertServiceJob(worker, preferred_type); |
| if (!worker->Start()) { |
| delete job; |
| delete worker; |
| @@ -308,7 +376,7 @@ int OriginBoundCertService::GetOriginBoundCert( |
| } |
| OriginBoundCertServiceRequest* request = |
| - new OriginBoundCertServiceRequest(callback, private_key, cert); |
| + new OriginBoundCertServiceRequest(callback, type, private_key, cert); |
| job->AddRequest(request); |
| *out_req = request; |
| return ERR_IO_PENDING; |
| @@ -316,31 +384,64 @@ int OriginBoundCertService::GetOriginBoundCert( |
| // static |
| int OriginBoundCertService::GenerateCert(const std::string& origin, |
| + SSLClientCertType type, |
| uint32 serial_number, |
| std::string* private_key, |
| std::string* cert) { |
| - scoped_ptr<crypto::RSAPrivateKey> key( |
| - crypto::RSAPrivateKey::Create(kKeySizeInBits)); |
| - if (!key.get()) { |
| - LOG(WARNING) << "Unable to create key pair for client"; |
| - return ERR_KEY_GENERATION_FAILED; |
| - } |
| std::string der_cert; |
| - if (!x509_util::CreateOriginBoundCertRSA( |
| - key.get(), |
| - origin, |
| - serial_number, |
| - base::TimeDelta::FromDays(kValidityPeriodInDays), |
| - &der_cert)) { |
| - LOG(WARNING) << "Unable to create x509 cert for client"; |
| - return ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED; |
| - } |
| - |
| std::vector<uint8> private_key_info; |
| - if (!key->ExportPrivateKey(&private_key_info)) { |
| - LOG(WARNING) << "Unable to export private key"; |
| - return ERR_PRIVATE_KEY_EXPORT_FAILED; |
| + switch (type) { |
| + case CLIENT_CERT_RSA_SIGN: { |
| + scoped_ptr<crypto::RSAPrivateKey> key( |
| + crypto::RSAPrivateKey::Create(kKeySizeInBits)); |
| + if (!key.get()) { |
| + DLOG(ERROR) << "Unable to create key pair for client"; |
| + return ERR_KEY_GENERATION_FAILED; |
| + } |
| + if (!x509_util::CreateOriginBoundCertRSA( |
| + key.get(), |
| + origin, |
| + serial_number, |
| + base::TimeDelta::FromDays(kValidityPeriodInDays), |
| + &der_cert)) { |
| + DLOG(ERROR) << "Unable to create x509 cert for client"; |
| + return ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED; |
| + } |
| + |
| + if (!key->ExportPrivateKey(&private_key_info)) { |
| + DLOG(ERROR) << "Unable to export private key"; |
| + return ERR_PRIVATE_KEY_EXPORT_FAILED; |
| + } |
| + break; |
| + } |
| + case CLIENT_CERT_ECDSA_SIGN: { |
| + scoped_ptr<crypto::ECPrivateKey> key(crypto::ECPrivateKey::Create()); |
| + if (!key.get()) { |
| + DLOG(ERROR) << "Unable to create key pair for client"; |
| + return ERR_KEY_GENERATION_FAILED; |
| + } |
| + if (!x509_util::CreateOriginBoundCertEC( |
| + key.get(), |
| + origin, |
| + serial_number, |
| + base::TimeDelta::FromDays(kValidityPeriodInDays), |
| + &der_cert)) { |
| + DLOG(ERROR) << "Unable to create x509 cert for client"; |
| + return ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED; |
| + } |
| + |
| + if (!key->ExportEncryptedPrivateKey( |
| + kEPKIPassword, 1, &private_key_info)) { |
| + DLOG(ERROR) << "Unable to export private key"; |
| + return ERR_PRIVATE_KEY_EXPORT_FAILED; |
| + } |
| + break; |
| + } |
| + default: |
| + NOTREACHED(); |
| + return ERR_INVALID_ARGUMENT; |
| } |
| + |
| // TODO(rkn): Perhaps ExportPrivateKey should be changed to output a |
| // std::string* to prevent this copying. |
| std::string key_out(private_key_info.begin(), private_key_info.end()); |
| @@ -360,12 +461,13 @@ void OriginBoundCertService::CancelRequest(RequestHandle req) { |
| // HandleResult is called by OriginBoundCertServiceWorker on the origin message |
| // loop. It deletes OriginBoundCertServiceJob. |
| void OriginBoundCertService::HandleResult(const std::string& origin, |
| - int error, |
| - const std::string& private_key, |
| - const std::string& cert) { |
| + int error, |
| + SSLClientCertType type, |
| + const std::string& private_key, |
| + const std::string& cert) { |
| DCHECK(CalledOnValidThread()); |
| - origin_bound_cert_store_->SetOriginBoundCert(origin, private_key, cert); |
| + origin_bound_cert_store_->SetOriginBoundCert(origin, type, private_key, cert); |
| std::map<std::string, OriginBoundCertServiceJob*>::iterator j; |
| j = inflight_.find(origin); |
| @@ -376,7 +478,7 @@ void OriginBoundCertService::HandleResult(const std::string& origin, |
| OriginBoundCertServiceJob* job = j->second; |
| inflight_.erase(j); |
| - job->HandleResult(error, private_key, cert); |
| + job->HandleResult(error, type, private_key, cert); |
| delete job; |
| } |