Index: net/socket/ssl_server_socket_openssl.cc |
diff --git a/net/socket/ssl_server_socket_openssl.cc b/net/socket/ssl_server_socket_openssl.cc |
index fe3a3694069ef21c987acd99c6d463b7527f9768..c261a721f20f756e87654c313e85f7132f3cdd2a 100644 |
--- a/net/socket/ssl_server_socket_openssl.cc |
+++ b/net/socket/ssl_server_socket_openssl.cc |
@@ -14,13 +14,125 @@ |
#include "crypto/rsa_private_key.h" |
#include "crypto/scoped_openssl_types.h" |
#include "net/base/net_errors.h" |
+#include "net/cert/cert_verifier.h" |
+#include "net/cert/cert_verify_result.h" |
#include "net/ssl/openssl_ssl_util.h" |
#include "net/ssl/scoped_openssl_types.h" |
+#include "net/ssl/ssl_connection_status_flags.h" |
+#include "net/ssl/ssl_info.h" |
#define GotoState(s) next_handshake_state_ = s |
namespace net { |
+namespace { |
+ |
+// TODO(dougsteed) These definitions copied from ssl_client_socket_openssl.cc. |
+// Might want to consider putting them in a common place. |
+void FreeX509Stack(STACK_OF(X509) * ptr) { |
+ sk_X509_pop_free(ptr, X509_free); |
+} |
+ |
+void FreeX509NameStack(STACK_OF(X509_NAME) * ptr) { |
+ sk_X509_NAME_pop_free(ptr, X509_NAME_free); |
+} |
+ |
+typedef crypto::ScopedOpenSSL<X509_NAME, X509_NAME_free> ScopedX509Name; |
+typedef crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack> ScopedX509Stack; |
+typedef crypto::ScopedOpenSSL<STACK_OF(X509_NAME), FreeX509NameStack> |
+ ScopedX509NameStack; |
+ |
+#if OPENSSL_VERSION_NUMBER < 0x1000103fL |
+// This method doesn't seem to have made it into the OpenSSL headers. |
+unsigned long SSL_CIPHER_get_id(const SSL_CIPHER* cipher) { |
+ return cipher->id; |
+} |
+#endif |
+ |
+// Used for encoding the |connection_status| field of an SSLInfo object. |
+int EncodeSSLConnectionStatus(int cipher_suite, int compression, int version) { |
+ return (cipher_suite & SSL_CONNECTION_CIPHERSUITE_MASK) | |
+ ((compression & SSL_CONNECTION_COMPRESSION_MASK) |
+ << SSL_CONNECTION_COMPRESSION_SHIFT) | |
+ ((version & SSL_CONNECTION_VERSION_MASK) |
+ << SSL_CONNECTION_VERSION_SHIFT); |
+} |
+ |
+// Returns the net SSL version number (see ssl_connection_status_flags.h) for |
+// this SSL connection. |
+int GetNetSSLVersion(SSL* ssl) { |
+ switch (SSL_version(ssl)) { |
+ case SSL2_VERSION: |
+ return SSL_CONNECTION_VERSION_SSL2; |
+ case SSL3_VERSION: |
+ return SSL_CONNECTION_VERSION_SSL3; |
+ case TLS1_VERSION: |
+ return SSL_CONNECTION_VERSION_TLS1; |
+ case 0x0302: |
+ return SSL_CONNECTION_VERSION_TLS1_1; |
+ case 0x0303: |
+ return SSL_CONNECTION_VERSION_TLS1_2; |
+ default: |
+ return SSL_CONNECTION_VERSION_UNKNOWN; |
+ } |
+} |
+ |
+bool GetX509AsDER(X509* cert, base::StringPiece* sp) { |
+ unsigned char* cert_data = NULL; |
+ int cert_data_length = i2d_X509(cert, &cert_data); |
+ if (!cert_data_length || !cert_data) { |
+ return false; |
+ } |
+ sp->set(reinterpret_cast<char*>(cert_data), cert_data_length); |
+ return true; |
+} |
+ |
+scoped_refptr<X509Certificate> CreateX509Certificate(X509* cert, |
+ STACK_OF(X509) * chain) { |
+ DCHECK(cert); |
+ std::vector<base::StringPiece> der_chain; |
+ base::StringPiece der_cert; |
+ scoped_refptr<X509Certificate> client_cert; |
+ if (!GetX509AsDER(cert, &der_cert)) |
+ return client_cert; |
+ der_chain.push_back(der_cert); |
+ |
+ ScopedX509Stack openssl_chain(X509_chain_up_ref(chain)); |
+ for (size_t i = 0; i < sk_X509_num(openssl_chain.get()); ++i) { |
+ X509* x = sk_X509_value(openssl_chain.get(), i); |
+ if (GetX509AsDER(x, &der_cert)) { |
+ der_chain.push_back(der_cert); |
+ } |
+ } |
+ |
+ client_cert = X509Certificate::CreateFromDERCertChain(der_chain); |
+ |
+ for (size_t i = 0; i < der_chain.size(); ++i) { |
+ OPENSSL_free(const_cast<char*>(der_chain[i].data())); |
+ } |
+ if (der_chain.size() - 1 != |
+ static_cast<size_t>(sk_X509_num(openssl_chain.get()))) { |
+ client_cert = NULL; |
+ } |
+ return client_cert; |
+} |
+ |
+void DoNothingOnCompletion(int ignore) {} |
+ |
+ScopedX509 OSCertHandleToOpenSSL(X509Certificate::OSCertHandle os_handle) { |
+#if defined(USE_OPENSSL_CERTS) |
+ return ScopedX509(X509Certificate::DupOSCertHandle(os_handle)); |
+#else |
+ std::string der_encoded; |
+ if (!X509Certificate::GetDEREncoded(os_handle, &der_encoded)) |
+ return ScopedX509(); |
+ const uint8_t* bytes = reinterpret_cast<const uint8_t*>(der_encoded.data()); |
+ return ScopedX509(d2i_X509(NULL, &bytes, der_encoded.size())); |
+#endif |
+} |
+ |
+} // namespace |
+ |
void EnableSSLServerSockets() { |
// No-op because CreateSSLServerSocket() calls crypto::EnsureOpenSSLInit(). |
} |
@@ -52,7 +164,9 @@ SSLServerSocketOpenSSL::SSLServerSocketOpenSSL( |
ssl_config_(ssl_config), |
cert_(certificate), |
next_handshake_state_(STATE_NONE), |
- completed_handshake_(false) { |
+ completed_handshake_(false), |
+ client_cert_ca_list_(), |
+ client_cert_verifier_(NULL) { |
// TODO(byungchul): Need a better way to clone a key. |
std::vector<uint8> key_bytes; |
CHECK(key->ExportPrivateKey(&key_bytes)); |
@@ -99,6 +213,20 @@ int SSLServerSocketOpenSSL::Handshake(const CompletionCallback& callback) { |
return rv > OK ? OK : rv; |
} |
+void SSLServerSocketOpenSSL::SetRequireClientCert(bool require_client_cert) { |
+ ssl_config_.require_client_cert = require_client_cert; |
+} |
+ |
+void SSLServerSocketOpenSSL::SetClientCertCAList( |
+ const CertificateList& client_cert_ca_list) { |
+ client_cert_ca_list_ = client_cert_ca_list; |
+} |
+ |
+void SSLServerSocketOpenSSL::SetClientCertVerifier( |
+ CertVerifier* client_cert_verifier) { |
+ client_cert_verifier_ = client_cert_verifier; |
+} |
+ |
int SSLServerSocketOpenSSL::ExportKeyingMaterial( |
const base::StringPiece& label, |
bool has_context, |
@@ -244,8 +372,31 @@ NextProto SSLServerSocketOpenSSL::GetNegotiatedProtocol() const { |
} |
bool SSLServerSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { |
- NOTIMPLEMENTED(); |
- return false; |
+ ssl_info->Reset(); |
+ if (!completed_handshake_) { |
+ return false; |
+ } |
+ ExtractClientCert(); |
+ ssl_info->cert = client_cert_; |
+ ssl_info->client_cert_sent = |
+ ssl_config_.require_client_cert && client_cert_.get(); |
+ |
+ const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); |
+ CHECK(cipher); |
+ ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); |
+ |
+ ssl_info->connection_status = |
+ EncodeSSLConnectionStatus(SSL_CIPHER_get_id(cipher), |
+ 0 /* no compression */, GetNetSSLVersion(ssl_)); |
+ |
+ if (!SSL_get_secure_renegotiation_support(ssl_)) |
+ ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; |
+ |
+ ssl_info->handshake_type = SSL_session_reused(ssl_) |
+ ? SSLInfo::HANDSHAKE_RESUME |
+ : SSLInfo::HANDSHAKE_FULL; |
+ |
+ return true; |
} |
void SSLServerSocketOpenSSL::GetConnectionAttempts( |
@@ -573,6 +724,13 @@ int SSLServerSocketOpenSSL::DoHandshake() { |
OpenSSLErrorInfo error_info; |
net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info); |
+ // This hack is necessary because the mapping of SSL error codes to |
+ // net_errors assumes (correctly for client sockets, but erroneously for |
+ // server sockets) that peer cert verification failure can only occur if |
+ // the cert changed during a renego. |
+ if (net_error == ERR_SSL_SERVER_CERT_CHANGED) |
+ net_error = ERR_BAD_SSL_CLIENT_AUTH_CERT; |
+ |
// If not done, stay in this state |
if (net_error == ERR_IO_PENDING) { |
GotoState(STATE_HANDSHAKE); |
@@ -618,10 +776,9 @@ int SSLServerSocketOpenSSL::Init() { |
crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); |
ScopedSSL_CTX ssl_ctx(SSL_CTX_new(SSLv23_server_method())); |
- |
if (ssl_config_.require_client_cert) |
SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_PEER, NULL); |
- |
+ SSL_CTX_set_cert_verify_callback(ssl_ctx.get(), CertVerifyCallback, this); |
ssl_ = SSL_new(ssl_ctx.get()); |
if (!ssl_) |
return ERR_UNEXPECTED; |
@@ -714,7 +871,59 @@ int SSLServerSocketOpenSSL::Init() { |
LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command |
<< "') returned " << rv; |
+ if (ssl_config_.require_client_cert) { |
+ if (client_cert_verifier_) |
+ ssl_->verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; |
+ if (!client_cert_ca_list_.empty()) { |
+ ScopedX509NameStack stack(sk_X509_NAME_new_null()); |
+ for (CertificateList::iterator it = client_cert_ca_list_.begin(); |
+ it != client_cert_ca_list_.end(); it++) { |
+ ScopedX509 ca_cert = OSCertHandleToOpenSSL(it->get()->os_cert_handle()); |
+ ScopedX509Name subj(X509_NAME_dup(ca_cert->cert_info->subject)); |
+ sk_X509_NAME_push(stack.get(), subj.release()); |
+ } |
+ SSL_set_client_CA_list(ssl_, stack.release()); |
+ } |
+ } |
+ |
return OK; |
} |
+void SSLServerSocketOpenSSL::ExtractClientCert() { |
+ if (client_cert_.get() || !completed_handshake_) { |
+ return; |
+ } |
+ X509* cert = SSL_get_peer_certificate(ssl_); |
+ STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); |
+ client_cert_ = CreateX509Certificate(cert, chain); |
+} |
+ |
+// static |
+int SSLServerSocketOpenSSL::CertVerifyCallback(X509_STORE_CTX* store_ctx, |
+ void* arg) { |
+ SSLServerSocketOpenSSL* self = reinterpret_cast<SSLServerSocketOpenSSL*>(arg); |
+ DCHECK(self); |
+ if (!self->client_cert_verifier_) |
+ return 1; |
+ SSL* ssl = reinterpret_cast<SSL*>(X509_STORE_CTX_get_ex_data( |
+ store_ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); |
+ DCHECK(ssl); |
+ X509* x = store_ctx->cert; |
+ STACK_OF(X509)* chain = store_ctx->chain; |
+ scoped_refptr<X509Certificate> client_cert(CreateX509Certificate(x, chain)); |
+ |
+ CertVerifyResult ignore_result; |
+ scoped_ptr<CertVerifier::Request> ignore_async; |
+ int res = self->client_cert_verifier_->Verify( |
+ client_cert.get(), std::string(), std::string(), 0, NULL, &ignore_result, |
+ base::Bind(&DoNothingOnCompletion), &ignore_async, self->net_log_); |
+ if (res == OK) { |
+ self->client_cert_ = client_cert; |
+ return 1; |
+ } else { |
+ X509_STORE_CTX_set_error(store_ctx, X509_V_ERR_CERT_REJECTED); |
+ return 0; |
+ } |
+} |
+ |
} // namespace net |