OLD | NEW |
---|---|
(Empty) | |
1 // Copyright 2014 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "chrome/browser/chromeos/platform_keys/platform_keys.h" | |
6 | |
7 #include <cryptohi.h> | |
8 | |
9 #include "base/bind.h" | |
10 #include "base/bind_helpers.h" | |
11 #include "base/callback.h" | |
12 #include "base/compiler_specific.h" | |
13 #include "base/logging.h" | |
14 #include "base/macros.h" | |
15 #include "base/single_thread_task_runner.h" | |
16 #include "base/thread_task_runner_handle.h" | |
17 #include "base/threading/worker_pool.h" | |
18 #include "chrome/browser/extensions/api/enterprise_platform_keys/enterprise_plat form_keys_api.h" | |
19 #include "chrome/browser/net/nss_context.h" | |
20 #include "crypto/rsa_private_key.h" | |
21 #include "net/base/crypto_module.h" | |
22 #include "net/base/net_errors.h" | |
23 #include "net/cert/cert_database.h" | |
24 #include "net/cert/nss_cert_database.h" | |
25 #include "net/cert/x509_certificate.h" | |
26 | |
27 namespace { | |
28 const char kErrorInternal[] = "Internal Error."; | |
29 const char kErrorKeyNotFound[] = "Key not found."; | |
30 const char kErrorCertificateNotFound[] = "Certificate could not be found."; | |
31 const char kErrorAlgorithmNotSupported[] = "Algorithm not supported."; | |
32 | |
33 // The current maximal RSA modulus length that ChromeOS's TPM supports for key | |
34 // generation. | |
35 const unsigned int kMaxRSAModulusLength = 2048; | |
36 } | |
37 | |
38 namespace chromeos { | |
39 | |
40 namespace platform_keys { | |
41 | |
42 namespace { | |
43 | |
44 // Base class to store state that is common to all NSS database operations and | |
45 // to provide convenience methods to call back. | |
46 // Keeps track of the originating task runner. | |
47 class NSSOperationState { | |
48 public: | |
49 explicit NSSOperationState(Profile* profile); | |
50 virtual ~NSSOperationState() {} | |
51 | |
52 // Called if an error occurred during the execution of the NSS operation | |
53 // described by this object. | |
54 virtual void OnError(const std::string& error_message) = 0; | |
55 | |
56 Profile* profile_; | |
57 crypto::ScopedPK11Slot slot_; | |
58 net::NSSCertDatabase* cert_db_; | |
59 | |
60 // The task runner on which the NSS operation was called. Any reply must be | |
61 // posted to this runner. | |
62 scoped_refptr<base::SingleThreadTaskRunner> origin_task_runner_; | |
63 | |
64 private: | |
65 DISALLOW_COPY_AND_ASSIGN(NSSOperationState); | |
66 }; | |
67 | |
68 typedef base::Closure GetCertDBCallback; | |
69 | |
70 // Callback of GetCertDatabase. Called back with the NSSCertDatabase associated | |
71 // to the given |token_id|. | |
72 void DidGetCertDB(const GetCertDBCallback& callback, | |
73 NSSOperationState* state, | |
74 net::NSSCertDatabase* cert_db) { | |
75 if (!cert_db) { | |
76 LOG(ERROR) << "Couldn't get NSSCertDatabase."; | |
77 state->OnError(kErrorInternal); | |
78 return; | |
79 } | |
80 | |
81 state->cert_db_ = cert_db; | |
82 state->slot_ = cert_db->GetPrivateSlot(); | |
83 if (!state->slot_) { | |
84 LOG(ERROR) << "No private slot"; | |
85 state->OnError(kErrorInternal); | |
86 return; | |
87 } | |
88 | |
89 base::WorkerPool::PostTask(FROM_HERE, callback, true /*task is slow*/); | |
mattm
2014/05/16 20:27:12
This is problematic, since the cert_db_ could go a
pneubeck (no reviews)
2014/05/19 19:17:19
Thanks for pointing this out, I though of this too
| |
90 } | |
91 | |
92 // Asynchronously fetches the NSSCertDatabase for |token_id| and passes it to | |
93 // |callback|. Will run |callback| on a worker thread. | |
94 void GetCertDatabase(const std::string& token_id, | |
95 const GetCertDBCallback& callback, | |
96 NSSOperationState* state) { | |
97 GetNSSCertDatabaseForProfile(state->profile_, | |
98 base::Bind(&DidGetCertDB, callback, state)); | |
99 } | |
100 | |
101 class GenerateRSAKeyState : public NSSOperationState { | |
102 public: | |
103 GenerateRSAKeyState(unsigned int modulus_length, | |
104 const GenerateKeyCallback& callback, | |
105 Profile* profile); | |
106 virtual ~GenerateRSAKeyState() {} | |
107 | |
108 virtual void OnError(const std::string& error_message) OVERRIDE { | |
109 CallBack(std::string() /* no public key */, error_message); | |
110 } | |
111 | |
112 void CallBack(const std::string& public_key_spki_der, | |
mattm
2014/05/16 20:27:12
Pass the FROM_HERE through the CallBack methods ar
pneubeck (no reviews)
2014/05/19 19:17:19
Done.
| |
113 const std::string& error_message) { | |
114 origin_task_runner_->PostTask( | |
115 FROM_HERE, base::Bind(callback_, public_key_spki_der, error_message)); | |
116 } | |
117 | |
118 unsigned int modulus_length_; | |
119 | |
120 private: | |
121 // Must be called on origin thread, use CallBack() therefore. | |
122 GenerateKeyCallback callback_; | |
123 }; | |
124 | |
125 class SignState : public NSSOperationState { | |
126 public: | |
127 SignState(const std::string& public_key, | |
128 const std::string& data, | |
129 const SignCallback& callback, | |
130 Profile* profile); | |
131 virtual ~SignState() {} | |
132 | |
133 virtual void OnError(const std::string& error_message) OVERRIDE { | |
134 CallBack(std::string() /* no signature */, error_message); | |
135 } | |
136 | |
137 void CallBack(const std::string& signature, | |
138 const std::string& error_message) { | |
139 origin_task_runner_->PostTask( | |
140 FROM_HERE, base::Bind(callback_, signature, error_message)); | |
141 } | |
142 | |
143 std::string public_key_; | |
144 std::string data_; | |
145 | |
146 private: | |
147 // Must be called on origin thread, use CallBack() therefore. | |
148 SignCallback callback_; | |
149 }; | |
150 | |
151 class GetCertificatesState : public NSSOperationState { | |
152 public: | |
153 GetCertificatesState(const GetCertificatesCallback& callback, | |
154 Profile* profile); | |
155 virtual ~GetCertificatesState() {} | |
156 | |
157 virtual void OnError(const std::string& error_message) OVERRIDE { | |
158 CallBack(scoped_ptr<net::CertificateList>() /* no certificates */, | |
159 error_message); | |
160 } | |
161 | |
162 void CallBack(scoped_ptr<net::CertificateList> certs, | |
163 const std::string& error_message) { | |
164 origin_task_runner_->PostTask( | |
165 FROM_HERE, base::Bind(callback_, base::Passed(&certs), error_message)); | |
166 } | |
167 | |
168 private: | |
169 // Must be called on origin thread, use CallBack() therefore. | |
170 GetCertificatesCallback callback_; | |
171 }; | |
172 | |
173 class ImportCertificateState : public NSSOperationState { | |
174 public: | |
175 ImportCertificateState(scoped_refptr<net::X509Certificate> certificate, | |
176 const ImportCertificateCallback& callback, | |
177 Profile* profile); | |
178 virtual ~ImportCertificateState() {} | |
179 | |
180 virtual void OnError(const std::string& error_message) OVERRIDE { | |
181 CallBack(error_message); | |
182 } | |
183 | |
184 void CallBack(const std::string& error_message) { | |
185 origin_task_runner_->PostTask(FROM_HERE, | |
186 base::Bind(callback_, error_message)); | |
187 } | |
188 | |
189 scoped_refptr<net::X509Certificate> certificate_; | |
190 | |
191 private: | |
192 // Must be called on origin thread, use CallBack() therefore. | |
193 ImportCertificateCallback callback_; | |
194 }; | |
195 | |
196 class RemoveCertificateState : public NSSOperationState { | |
197 public: | |
198 RemoveCertificateState(scoped_refptr<net::X509Certificate> certificate, | |
199 const RemoveCertificateCallback& callback, | |
200 Profile* profile); | |
201 virtual ~RemoveCertificateState() {} | |
202 | |
203 virtual void OnError(const std::string& error_message) OVERRIDE { | |
204 CallBack(error_message); | |
205 } | |
206 | |
207 void CallBack(const std::string& error_message) { | |
208 origin_task_runner_->PostTask(FROM_HERE, | |
209 base::Bind(callback_, error_message)); | |
210 } | |
211 | |
212 scoped_refptr<net::X509Certificate> certificate_; | |
213 | |
214 private: | |
215 // Must be called on origin thread, use CallBack() therefore. | |
216 RemoveCertificateCallback callback_; | |
217 }; | |
218 | |
219 NSSOperationState::NSSOperationState(Profile* profile) | |
220 : profile_(profile), | |
221 cert_db_(NULL), | |
222 origin_task_runner_(base::ThreadTaskRunnerHandle::Get()) { | |
223 } | |
224 | |
225 GenerateRSAKeyState::GenerateRSAKeyState(unsigned int modulus_length, | |
226 const GenerateKeyCallback& callback, | |
227 Profile* profile) | |
228 : NSSOperationState(profile), | |
229 modulus_length_(modulus_length), | |
230 callback_(callback) { | |
231 } | |
232 | |
233 SignState::SignState(const std::string& public_key, | |
234 const std::string& data, | |
235 const SignCallback& callback, | |
236 Profile* profile) | |
237 : NSSOperationState(profile), | |
238 public_key_(public_key), | |
239 data_(data), | |
240 callback_(callback) { | |
241 } | |
242 | |
243 GetCertificatesState::GetCertificatesState( | |
244 const GetCertificatesCallback& callback, | |
245 Profile* profile) | |
246 : NSSOperationState(profile), callback_(callback) { | |
247 } | |
248 | |
249 ImportCertificateState::ImportCertificateState( | |
250 scoped_refptr<net::X509Certificate> certificate, | |
251 const ImportCertificateCallback& callback, | |
252 Profile* profile) | |
253 : NSSOperationState(profile), | |
254 certificate_(certificate), | |
255 callback_(callback) { | |
256 } | |
257 | |
258 RemoveCertificateState::RemoveCertificateState( | |
259 scoped_refptr<net::X509Certificate> certificate, | |
260 const RemoveCertificateCallback& callback, | |
261 Profile* profile) | |
262 : NSSOperationState(profile), | |
263 certificate_(certificate), | |
264 callback_(callback) { | |
265 } | |
266 | |
267 // Continues generating a RSA key with the obtained NSSCertDatabase. Used by | |
268 // GenerateRSAKey(). | |
269 void GenerateRSAKeyWithDB(scoped_ptr<GenerateRSAKeyState> state) { | |
270 if (state->modulus_length_ > kMaxRSAModulusLength) { | |
271 state->OnError(kErrorAlgorithmNotSupported); | |
272 return; | |
273 } | |
274 scoped_ptr<crypto::RSAPrivateKey> rsa_key( | |
275 crypto::RSAPrivateKey::CreateSensitive(state->slot_.get(), | |
276 state->modulus_length_)); | |
277 if (!rsa_key) { | |
278 LOG(ERROR) << "Couldn't create key."; | |
279 state->OnError(kErrorInternal); | |
280 return; | |
281 } | |
282 | |
283 std::vector<uint8> public_key_spki_der; | |
284 if (!rsa_key->ExportPublicKey(&public_key_spki_der)) { | |
285 // TODO(pneubeck): Remove rsa_key from storage. | |
286 LOG(ERROR) << "Couldn't export public key."; | |
287 state->OnError(kErrorInternal); | |
288 return; | |
289 } | |
290 state->CallBack( | |
291 std::string(public_key_spki_der.begin(), public_key_spki_der.end()), | |
292 std::string() /* no error */); | |
293 } | |
294 | |
295 // Continues signing with the obtained NSSCertDatabase. Used by Sign(). | |
296 void RSASignWithDB(scoped_ptr<SignState> state) { | |
297 const uint8* public_key_uint8 = | |
298 reinterpret_cast<const uint8*>(state->public_key_.data()); | |
299 std::vector<uint8> public_key_vector( | |
300 public_key_uint8, public_key_uint8 + state->public_key_.size()); | |
301 | |
302 // TODO(pneubeck): This searches all slots. Change to look only at |slot_|. | |
303 scoped_ptr<crypto::RSAPrivateKey> rsa_key( | |
304 crypto::RSAPrivateKey::FindFromPublicKeyInfo(public_key_vector)); | |
305 if (!rsa_key) { | |
306 state->OnError(kErrorKeyNotFound); | |
307 return; | |
308 } | |
309 | |
310 SECItem sign_result = {siBuffer, NULL, 0}; | |
311 if (SEC_SignData(&sign_result, | |
312 reinterpret_cast<const unsigned char*>(state->data_.data()), | |
313 state->data_.size(), | |
314 rsa_key->key(), | |
315 SEC_OID_PKCS1_SHA1_WITH_RSA_ENCRYPTION) != SECSuccess) { | |
316 LOG(ERROR) << "Couldn't sign."; | |
317 state->OnError(kErrorInternal); | |
318 return; | |
319 } | |
320 | |
321 std::string signature(reinterpret_cast<const char*>(sign_result.data), | |
322 sign_result.len); | |
323 state->CallBack(signature, std::string() /* no error */); | |
324 } | |
325 | |
326 // Continues getting certificates with the certificates returned by | |
327 // NSSCertDatabase::ListCertsInSlot. Used by GetCertificatesWithDB(). | |
328 void DidGetCertificates(scoped_ptr<GetCertificatesState> state, | |
329 scoped_ptr<net::CertificateList> all_certs) { | |
330 scoped_ptr<net::CertificateList> client_certs(new net::CertificateList); | |
331 for (net::CertificateList::const_iterator it = all_certs->begin(); | |
332 it != all_certs->end(); | |
333 ++it) { | |
334 net::X509Certificate::OSCertHandle cert_handle = (*it)->os_cert_handle(); | |
335 crypto::ScopedPK11Slot cert_slot(PK11_KeyForCertExists(cert_handle, | |
336 NULL, // keyPtr | |
337 NULL)); // wincx | |
338 | |
339 // Keep only user certificates, i.e. certs for which the private key is | |
340 // present and stored in the queried slot. | |
341 if (cert_slot != state->slot_) | |
342 continue; | |
343 | |
344 client_certs->push_back(*it); | |
345 } | |
346 | |
347 state->CallBack(client_certs.Pass(), std::string() /* no error */); | |
348 } | |
349 | |
350 // Continues getting certificates with the obtained NSSCertDatabase. Used by | |
351 // GetCertificates(). | |
352 void GetCertificatesWithDB(scoped_ptr<GetCertificatesState> state) { | |
353 // Get the pointer to slot before base::Passed releases |state|. | |
354 PK11SlotInfo* slot = state->slot_.get(); | |
355 state->cert_db_->ListCertsInSlot( | |
356 base::Bind(&DidGetCertificates, base::Passed(&state)), slot); | |
357 } | |
358 | |
359 // Continues certificate importing with the obtained NSSCertDatabase. Used by | |
360 // ImportCertificate(). | |
361 void ImportCertificateWithDB(scoped_ptr<ImportCertificateState> state) { | |
362 net::CertDatabase* db = net::CertDatabase::GetInstance(); | |
363 | |
364 const net::Error cert_status = db->CheckUserCert(state->certificate_); | |
365 if (cert_status == net::ERR_NO_PRIVATE_KEY_FOR_CERT) { | |
366 state->OnError(kErrorKeyNotFound); | |
367 return; | |
368 } else if (cert_status != net::OK) { | |
369 state->OnError(net::ErrorToString(cert_status)); | |
370 return; | |
371 } | |
372 | |
373 const net::Error import_status = db->AddUserCert(state->certificate_.get()); | |
374 if (import_status != net::OK) { | |
375 LOG(ERROR) << "Could not import certificate."; | |
376 state->OnError(net::ErrorToString(import_status)); | |
377 return; | |
378 } | |
379 | |
380 state->CallBack(std::string() /* no error */); | |
381 } | |
382 | |
383 // Continues certificate removal with the obtained NSSCertDatabase. Used by | |
384 // RemoveCertificate(). | |
385 void RemoveCertificateWithDB(scoped_ptr<RemoveCertificateState> state) { | |
386 bool certificate_found = state->certificate_->os_cert_handle()->isperm; | |
387 bool success = state->cert_db_->DeleteCertAndKey(state->certificate_); | |
388 | |
389 // CertificateNotFound error has precedence over an internal error. | |
390 if (!certificate_found) { | |
391 state->OnError(kErrorCertificateNotFound); | |
392 return; | |
393 } | |
394 if (!success) { | |
395 state->OnError(kErrorInternal); | |
396 return; | |
397 } | |
398 | |
399 state->CallBack(std::string() /* no error */); | |
400 } | |
401 | |
402 } // namespace | |
403 | |
404 void GenerateRSAKey(const std::string& token_id, | |
405 unsigned int modulus_length, | |
406 const GenerateKeyCallback& callback, | |
407 Profile* profile) { | |
408 scoped_ptr<GenerateRSAKeyState> state( | |
409 new GenerateRSAKeyState(modulus_length, callback, profile)); | |
410 // Get the pointer to |state| before base::Passed releases |state|. | |
411 NSSOperationState* state_ptr = state.get(); | |
412 GetCertDatabase(token_id, | |
413 base::Bind(&GenerateRSAKeyWithDB, base::Passed(&state)), | |
414 state_ptr); | |
415 } | |
416 | |
417 void Sign(const std::string& token_id, | |
418 const std::string& public_key, | |
419 const std::string& data, | |
420 const SignCallback& callback, | |
421 Profile* profile) { | |
422 scoped_ptr<SignState> state( | |
423 new SignState(public_key, data, callback, profile)); | |
424 // Get the pointer to |state| before base::Passed releases |state|. | |
425 NSSOperationState* state_ptr = state.get(); | |
426 GetCertDatabase( | |
427 token_id, base::Bind(&RSASignWithDB, base::Passed(&state)), state_ptr); | |
428 } | |
429 | |
430 void GetCertificates(const std::string& token_id, | |
431 const GetCertificatesCallback& callback, | |
432 Profile* profile) { | |
433 scoped_ptr<GetCertificatesState> state( | |
434 new GetCertificatesState(callback, profile)); | |
435 // Get the pointer to |state| before base::Passed releases |state|. | |
436 NSSOperationState* state_ptr = state.get(); | |
437 GetCertDatabase(token_id, | |
438 base::Bind(&GetCertificatesWithDB, base::Passed(&state)), | |
439 state_ptr); | |
440 } | |
441 | |
442 void ImportCertificate(const std::string& token_id, | |
443 scoped_refptr<net::X509Certificate> certificate, | |
444 const ImportCertificateCallback& callback, | |
445 Profile* profile) { | |
446 scoped_ptr<ImportCertificateState> state( | |
447 new ImportCertificateState(certificate, callback, profile)); | |
448 // Get the pointer to |state| before base::Passed releases |state|. | |
449 NSSOperationState* state_ptr = state.get(); | |
450 GetCertDatabase(token_id, | |
451 base::Bind(&ImportCertificateWithDB, base::Passed(&state)), | |
452 state_ptr); | |
453 } | |
454 | |
455 void RemoveCertificate(const std::string& token_id, | |
456 scoped_refptr<net::X509Certificate> certificate, | |
457 const RemoveCertificateCallback& callback, | |
458 Profile* profile) { | |
459 scoped_ptr<RemoveCertificateState> state( | |
460 new RemoveCertificateState(certificate, callback, profile)); | |
461 // Get the pointer to |state| before base::Passed releases |state|. | |
462 NSSOperationState* state_ptr = state.get(); | |
463 GetCertDatabase(token_id, | |
464 base::Bind(&RemoveCertificateWithDB, base::Passed(&state)), | |
465 state_ptr); | |
466 } | |
467 | |
468 } // namespace platform_keys | |
469 | |
470 } // namespace chromeos | |
OLD | NEW |