Chromium Code Reviews| Index: runtime/bin/tls_socket.cc |
| diff --git a/runtime/bin/tls_socket.cc b/runtime/bin/tls_socket.cc |
| index c4bc0a533ab71a7fa5b65871e6555b51557802d5..76d50e8616578018c3273fb34120220b2aa10334 100644 |
| --- a/runtime/bin/tls_socket.cc |
| +++ b/runtime/bin/tls_socket.cc |
| @@ -11,10 +11,12 @@ |
| #include <string.h> |
| #include <nss.h> |
| +#include <pk11pub.h> |
| #include <prerror.h> |
| #include <prinit.h> |
| #include <prnetdb.h> |
| #include <ssl.h> |
| +#include <sslproto.h> |
| #include "bin/builtin.h" |
| #include "bin/dartutils.h" |
| @@ -27,6 +29,10 @@ |
| bool TlsFilter::library_initialized_ = false; |
| dart::Mutex TlsFilter::mutex_; // To protect library initialization. |
| +// The password is needed when creating secure server sockets. It can |
| +// be null if only secure client sockets are used. |
| +const char* TlsFilter::password_ = NULL; |
| + |
| static const int kTlsFilterNativeFieldIndex = 0; |
| static TlsFilter* GetTlsFilter(Dart_NativeArguments args) { |
| @@ -63,21 +69,45 @@ void FUNCTION_NAME(TlsSocket_Init)(Dart_NativeArguments args) { |
| void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) { |
| Dart_EnterScope(); |
| - Dart_Handle host_name = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
| + Dart_Handle host_name_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
| Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2)); |
| + Dart_Handle is_server_object = ThrowIfError(Dart_GetNativeArgument(args, 3)); |
| + Dart_Handle certificate_name_object = |
| + ThrowIfError(Dart_GetNativeArgument(args, 4)); |
| - const char* host_name_string = NULL; |
| + const char* host_name = NULL; |
| // TODO(whesse): Is truncating a Dart string containing \0 what we want? |
| - ThrowIfError(Dart_StringToCString(host_name, &host_name_string)); |
| + ThrowIfError(Dart_StringToCString(host_name_object, &host_name)); |
| int64_t port; |
| if (!DartUtils::GetInt64Value(port_object, &port) || |
| port < 0 || port > 65535) { |
| Dart_ThrowException(DartUtils::NewDartArgumentError( |
| - "Illegal port parameter in TlsSocket")); |
| + "Illegal port parameter in _TlsFilter.connect")); |
| + } |
| + |
| + if (!Dart_IsBoolean(is_server_object)) { |
| + Dart_ThrowException(DartUtils::NewDartArgumentError( |
| + "Illegal is_server parameter in _TlsFilter.connect")); |
| + } |
| + bool is_server = DartUtils::GetBooleanValue(is_server_object); |
| + |
| + const char* certificate_name = NULL; |
| + // If this is a server connection, get the certificate to connect with. |
| + // TODO(whesse): Use this parameter for a client certificate as well. |
| + if (is_server) { |
| + if (!Dart_IsString(certificate_name_object)) { |
| + Dart_ThrowException(DartUtils::NewDartArgumentError( |
| + "Non-String certificate parameter in _TlsFilter.connect")); |
| + } |
| + ThrowIfError(Dart_StringToCString(certificate_name_object, |
| + &certificate_name)); |
| } |
| - GetTlsFilter(args)->Connect(host_name_string, static_cast<int>(port)); |
| + GetTlsFilter(args)->Connect(host_name, |
| + static_cast<int>(port), |
| + is_server, |
| + certificate_name); |
| Dart_ExitScope(); |
| } |
| @@ -132,16 +162,33 @@ void FUNCTION_NAME(TlsSocket_ProcessBuffer)(Dart_NativeArguments args) { |
| void FUNCTION_NAME(TlsSocket_SetCertificateDatabase) |
| (Dart_NativeArguments args) { |
| Dart_EnterScope(); |
| - Dart_Handle dart_pkcert_dir = ThrowIfError(Dart_GetNativeArgument(args, 0)); |
| + Dart_Handle certificate_database_object = |
| + ThrowIfError(Dart_GetNativeArgument(args, 0)); |
| // Check that the type is string, and get the UTF-8 C string value from it. |
| - if (Dart_IsString(dart_pkcert_dir)) { |
| - const char* pkcert_dir = NULL; |
| - ThrowIfError(Dart_StringToCString(dart_pkcert_dir, &pkcert_dir)); |
| - TlsFilter::InitializeLibrary(pkcert_dir); |
| + const char* certificate_database = NULL; |
| + if (Dart_IsString(certificate_database_object)) { |
| + ThrowIfError(Dart_StringToCString(certificate_database_object, |
| + &certificate_database)); |
| } else { |
| Dart_ThrowException(DartUtils::NewDartArgumentError( |
| - "Non-String argument to SetCertificateDatabase")); |
| + "Non-String certificate directory argument to SetCertificateDatabase")); |
| } |
| + |
| + Dart_Handle password_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
| + // Check that the type is string or null, |
| + // and get the UTF-8 C string value from it. |
| + const char* password = NULL; |
| + if (Dart_IsString(password_object)) { |
| + ThrowIfError(Dart_StringToCString(password_object, &password)); |
| + } else if (Dart_IsNull(password_object)) { |
| + // Leave password as NULL. |
| + password = NULL; |
|
Mads Ager (google)
2012/11/20 14:59:45
In that case this is a noop, right? So maybe leave
Bill Hesse
2012/11/20 17:46:55
It is difficult to avoid null pointer uses later o
|
| + } else { |
| + Dart_ThrowException(DartUtils::NewDartArgumentError( |
| + "Password argument to SetCertificateDatabase is not a String or null")); |
| + } |
| + |
| + TlsFilter::InitializeLibrary(certificate_database, password); |
| Dart_ExitScope(); |
| } |
| @@ -153,7 +200,7 @@ void TlsFilter::Init(Dart_Handle dart_this) { |
| Dart_NewPersistentHandle(DartUtils::NewString("length"))); |
| InitializeBuffers(dart_this); |
| - memio_ = memio_CreateIOLayer(kMemioBufferSize); |
| + filter_ = memio_CreateIOLayer(kMemioBufferSize); |
| } |
| @@ -193,12 +240,15 @@ void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) { |
| } |
| -void TlsFilter::InitializeLibrary(const char* pkcert_database) { |
| +void TlsFilter::InitializeLibrary(const char* certificate_database, |
| + const char* password) { |
| MutexLocker locker(&mutex_); |
| if (!library_initialized_) { |
| + library_initialized_ = true; |
| + password_ = strdup(password); // This one copy persists until Dart exits. |
| PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0); |
| // TODO(whesse): Verify there are no UTF-8 issues here. |
| - SECStatus status = NSS_Init(pkcert_database); |
| + SECStatus status = NSS_Init(certificate_database); |
| if (status != SECSuccess) { |
| ThrowPRException("Unsuccessful NSS_Init call."); |
| } |
| @@ -207,28 +257,77 @@ void TlsFilter::InitializeLibrary(const char* pkcert_database) { |
| if (status != SECSuccess) { |
| ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call."); |
| } |
| + // Enable TLS, as well as SSL3 and SSL2. |
| + status = SSL_OptionSetDefault(SSL_ENABLE_TLS, PR_TRUE); |
| + if (status != SECSuccess) { |
| + ThrowPRException("Unsuccessful SSL_CipherPrefSetDefault RC4 call."); |
|
Mads Ager (google)
2012/11/20 14:59:45
This text does not match the call you are making.
Bill Hesse
2012/11/20 17:46:55
Done.
|
| + } |
| } else { |
| ThrowException("Called TlsFilter::InitializeLibrary more than once"); |
| } |
| } |
| +char* PasswordCallback(PK11SlotInfo* slot, PRBool retry, void* arg) { |
| + if (!retry) { |
| + return PL_strdup(static_cast<char*>(arg)); // Freed by NSS internals. |
| + } |
| + return NULL; |
| +} |
| -void TlsFilter::Connect(const char* host, int port) { |
| +void TlsFilter::Connect(const char* host_name, |
| + int port, |
| + bool is_server, |
| + const char* certificate_name) { |
| + is_server_ = is_server; |
| if (in_handshake_) { |
| ThrowException("Connect called while already in handshake state."); |
| } |
| - PRFileDesc* my_socket = memio_; |
| - my_socket = SSL_ImportFD(NULL, my_socket); |
| - if (my_socket == NULL) { |
| + filter_ = SSL_ImportFD(NULL, filter_); |
| + if (filter_ == NULL) { |
| ThrowPRException("Unsuccessful SSL_ImportFD call"); |
| } |
| - if (SSL_SetURL(my_socket, host) == -1) { |
| - ThrowPRException("Unsuccessful SetURL call"); |
| + SECStatus status; |
| + if (is_server) { |
| + CERTCertDBHandle* certificate_database = CERT_GetDefaultCertDB(); |
| + ASSERT(certificate_database != NULL); |
| + CERTCertificate* certificate = CERT_FindCertByNameString( |
| + certificate_database, |
| + const_cast<char*>(certificate_name)); |
| + ASSERT(certificate != NULL); |
| + |
| + PK11_SetPasswordFunc(PasswordCallback); |
| + |
|
Mads Ager (google)
2012/11/20 14:59:45
Remove one of these empty lines? Or maybe both of
Bill Hesse
2012/11/20 17:46:55
Both removed, and SetPasswordFunc moved up to top
|
| + |
| + SECKEYPrivateKey* key = PK11_FindKeyByAnyCert( |
| + certificate, |
| + static_cast<void*>(const_cast<char*>(password_))); |
| + if (key == NULL) { |
| + if (PR_GetError() == -8177) { |
| + ThrowPRException( |
|
Mads Ager (google)
2012/11/20 14:59:45
Indentation is off.
Bill Hesse
2012/11/20 17:46:55
Done.
|
| + "Certificate database password incorrect"); |
| + } else { |
| + ThrowPRException( |
|
Mads Ager (google)
2012/11/20 14:59:45
Indentation.
Bill Hesse
2012/11/20 17:46:55
Done.
|
| + "Unsuccessful PK11_FindKeyByAnyCert call." |
| + " Cannot find private key for certificate"); |
| + } |
| + } |
| + ASSERT(key != NULL); |
| + // kt_rsa (key type RSA) is an enum constant from the NSS libraries. |
| + // TODO(whesse): Allow different key types. |
| + status = SSL_ConfigSecureServer(filter_, certificate, key, kt_rsa); |
| + if (status != SECSuccess) { |
| + ThrowPRException("Unsuccessful SSL_ConfigSecureServer call"); |
| + } |
| + } else { // Client. |
| + if (SSL_SetURL(filter_, host_name) == -1) { |
| + ThrowPRException("Unsuccessful SetURL call"); |
| + } |
| } |
| - SECStatus status = SSL_ResetHandshake(my_socket, PR_FALSE); |
| + PRBool as_server = is_server ? PR_TRUE : PR_FALSE; // Convert bool to PRBool. |
| + status = SSL_ResetHandshake(filter_, as_server); |
| if (status != SECSuccess) { |
| ThrowPRException("Unsuccessful SSL_ResetHandshake call"); |
| } |
| @@ -237,7 +336,7 @@ void TlsFilter::Connect(const char* host, int port) { |
| PRNetAddr host_address; |
| char host_entry_buffer[PR_NETDB_BUF_SIZE]; |
| PRHostEnt host_entry; |
| - PRStatus rv = PR_GetHostByName(host, host_entry_buffer, |
| + PRStatus rv = PR_GetHostByName(host_name, host_entry_buffer, |
| PR_NETDB_BUF_SIZE, &host_entry); |
| if (rv != PR_SUCCESS) { |
| ThrowPRException("Unsuccessful PR_GetHostByName call"); |
| @@ -247,14 +346,12 @@ void TlsFilter::Connect(const char* host, int port) { |
| if (index == -1 || index == 0) { |
| ThrowPRException("Unsuccessful PR_EnumerateHostEnt call"); |
| } |
| - |
| - memio_SetPeerName(my_socket, &host_address); |
| - memio_ = my_socket; |
| + memio_SetPeerName(filter_, &host_address); |
| } |
| void TlsFilter::Handshake() { |
| - SECStatus status = SSL_ForceHandshake(memio_); |
| + SECStatus status = SSL_ForceHandshake(filter_); |
| if (status == SECSuccess) { |
| if (in_handshake_) { |
| ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL)); |
| @@ -267,7 +364,11 @@ void TlsFilter::Handshake() { |
| in_handshake_ = true; |
| } |
| } else { |
| - ThrowPRException("Unexpected handshake error"); |
| + if (is_server_) { |
| + ThrowPRException("Unexpected handshake error in server"); |
| + } else { |
| + ThrowPRException("Unexpected handshake error in client"); |
| + } |
| } |
| } |
| } |
| @@ -305,7 +406,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
| switch (buffer_index) { |
| case kReadPlaintext: { |
| int bytes_free = buffer_size_ - start - length; |
| - bytes_processed = PR_Read(memio_, |
| + bytes_processed = PR_Read(filter_, |
| buffer + start + length, |
| bytes_free); |
| if (bytes_processed < 0) { |
| @@ -326,7 +427,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
| unsigned int len1; |
| unsigned int len2; |
| int bytes_free = buffer_size_ - start - length; |
| - memio_Private* secret = memio_GetSecret(memio_); |
| + memio_Private* secret = memio_GetSecret(filter_); |
| memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2); |
| int bytes_to_send = |
| dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free)); |
| @@ -350,11 +451,11 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
| case kReadEncrypted: { |
| if (length > 0) { |
| bytes_processed = length; |
| - memio_Private* secret = memio_GetSecret(memio_); |
| - uint8_t* memio_buf; |
| - int free_bytes = memio_GetReadParams(secret, &memio_buf); |
| + memio_Private* secret = memio_GetSecret(filter_); |
| + uint8_t* filter_buf; |
| + int free_bytes = memio_GetReadParams(secret, &filter_buf); |
| if (free_bytes < bytes_processed) bytes_processed = free_bytes; |
| - memmove(memio_buf, |
| + memmove(filter_buf, |
| buffer + start, |
| bytes_processed); |
| memio_PutReadResult(secret, bytes_processed); |
| @@ -364,7 +465,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
| case kWritePlaintext: { |
| if (length > 0) { |
| - bytes_processed = PR_Write(memio_, |
| + bytes_processed = PR_Write(filter_, |
| buffer + start, |
| length); |
| } |