Index: runtime/bin/tls_socket.cc |
diff --git a/runtime/bin/tls_socket.cc b/runtime/bin/tls_socket.cc |
index c4bc0a533ab71a7fa5b65871e6555b51557802d5..53ee162c772f5cc55fc1eb292fa32996eb93ac41 100644 |
--- a/runtime/bin/tls_socket.cc |
+++ b/runtime/bin/tls_socket.cc |
@@ -11,10 +11,12 @@ |
#include <string.h> |
#include <nss.h> |
+#include <pk11pub.h> |
#include <prerror.h> |
#include <prinit.h> |
#include <prnetdb.h> |
#include <ssl.h> |
+#include <sslproto.h> |
#include "bin/builtin.h" |
#include "bin/dartutils.h" |
@@ -27,6 +29,10 @@ |
bool TlsFilter::library_initialized_ = false; |
dart::Mutex TlsFilter::mutex_; // To protect library initialization. |
+// The password is needed when creating secure server sockets. It can |
+// be null if only secure client sockets are used. |
+const char* TlsFilter::password_ = NULL; |
+ |
static const int kTlsFilterNativeFieldIndex = 0; |
static TlsFilter* GetTlsFilter(Dart_NativeArguments args) { |
@@ -63,21 +69,45 @@ void FUNCTION_NAME(TlsSocket_Init)(Dart_NativeArguments args) { |
void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) { |
Dart_EnterScope(); |
- Dart_Handle host_name = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
+ Dart_Handle host_name_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2)); |
+ Dart_Handle is_server_object = ThrowIfError(Dart_GetNativeArgument(args, 3)); |
+ Dart_Handle certificate_name_object = |
+ ThrowIfError(Dart_GetNativeArgument(args, 4)); |
- const char* host_name_string = NULL; |
+ const char* host_name = NULL; |
// TODO(whesse): Is truncating a Dart string containing \0 what we want? |
- ThrowIfError(Dart_StringToCString(host_name, &host_name_string)); |
+ ThrowIfError(Dart_StringToCString(host_name_object, &host_name)); |
int64_t port; |
if (!DartUtils::GetInt64Value(port_object, &port) || |
port < 0 || port > 65535) { |
Dart_ThrowException(DartUtils::NewDartArgumentError( |
- "Illegal port parameter in TlsSocket")); |
+ "Illegal port parameter in _TlsFilter.connect")); |
+ } |
+ |
+ if (!Dart_IsBoolean(is_server_object)) { |
+ Dart_ThrowException(DartUtils::NewDartArgumentError( |
+ "Illegal is_server parameter in _TlsFilter.connect")); |
+ } |
+ bool is_server = DartUtils::GetBooleanValue(is_server_object); |
+ |
+ const char* certificate_name = NULL; |
+ // If this is a server connection, get the certificate to connect with. |
+ // TODO(whesse): Use this parameter for a client certificate as well. |
+ if (is_server) { |
+ if (!Dart_IsString(certificate_name_object)) { |
+ Dart_ThrowException(DartUtils::NewDartArgumentError( |
+ "Non-String certificate parameter in _TlsFilter.connect")); |
+ } |
+ ThrowIfError(Dart_StringToCString(certificate_name_object, |
+ &certificate_name)); |
} |
- GetTlsFilter(args)->Connect(host_name_string, static_cast<int>(port)); |
+ GetTlsFilter(args)->Connect(host_name, |
+ static_cast<int>(port), |
+ is_server, |
+ certificate_name); |
Dart_ExitScope(); |
} |
@@ -132,16 +162,33 @@ void FUNCTION_NAME(TlsSocket_ProcessBuffer)(Dart_NativeArguments args) { |
void FUNCTION_NAME(TlsSocket_SetCertificateDatabase) |
(Dart_NativeArguments args) { |
Dart_EnterScope(); |
- Dart_Handle dart_pkcert_dir = ThrowIfError(Dart_GetNativeArgument(args, 0)); |
+ Dart_Handle certificate_database_object = |
+ ThrowIfError(Dart_GetNativeArgument(args, 0)); |
// Check that the type is string, and get the UTF-8 C string value from it. |
- if (Dart_IsString(dart_pkcert_dir)) { |
- const char* pkcert_dir = NULL; |
- ThrowIfError(Dart_StringToCString(dart_pkcert_dir, &pkcert_dir)); |
- TlsFilter::InitializeLibrary(pkcert_dir); |
+ const char* certificate_database = NULL; |
+ if (Dart_IsString(certificate_database_object)) { |
+ ThrowIfError(Dart_StringToCString(certificate_database_object, |
+ &certificate_database)); |
} else { |
Dart_ThrowException(DartUtils::NewDartArgumentError( |
- "Non-String argument to SetCertificateDatabase")); |
+ "Non-String certificate directory argument to SetCertificateDatabase")); |
} |
+ |
+ Dart_Handle password_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
+ // Check that the type is string or null, |
+ // and get the UTF-8 C string value from it. |
+ const char* password = NULL; |
+ if (Dart_IsString(password_object)) { |
+ ThrowIfError(Dart_StringToCString(password_object, &password)); |
+ } else if (Dart_IsNull(password_object)) { |
+ // Pass the empty string as the password. |
+ password = ""; |
+ } else { |
+ Dart_ThrowException(DartUtils::NewDartArgumentError( |
+ "Password argument to SetCertificateDatabase is not a String or null")); |
+ } |
+ |
+ TlsFilter::InitializeLibrary(certificate_database, password); |
Dart_ExitScope(); |
} |
@@ -153,7 +200,7 @@ void TlsFilter::Init(Dart_Handle dart_this) { |
Dart_NewPersistentHandle(DartUtils::NewString("length"))); |
InitializeBuffers(dart_this); |
- memio_ = memio_CreateIOLayer(kMemioBufferSize); |
+ filter_ = memio_CreateIOLayer(kMemioBufferSize); |
} |
@@ -193,12 +240,15 @@ void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) { |
} |
-void TlsFilter::InitializeLibrary(const char* pkcert_database) { |
+void TlsFilter::InitializeLibrary(const char* certificate_database, |
+ const char* password) { |
MutexLocker locker(&mutex_); |
if (!library_initialized_) { |
+ library_initialized_ = true; |
+ password_ = strdup(password); // This one copy persists until Dart exits. |
PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0); |
// TODO(whesse): Verify there are no UTF-8 issues here. |
- SECStatus status = NSS_Init(pkcert_database); |
+ SECStatus status = NSS_Init(certificate_database); |
if (status != SECSuccess) { |
ThrowPRException("Unsuccessful NSS_Init call."); |
} |
@@ -207,28 +257,75 @@ void TlsFilter::InitializeLibrary(const char* pkcert_database) { |
if (status != SECSuccess) { |
ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call."); |
} |
+ // Enable TLS, as well as SSL3 and SSL2. |
+ status = SSL_OptionSetDefault(SSL_ENABLE_TLS, PR_TRUE); |
+ if (status != SECSuccess) { |
+ ThrowPRException("Unsuccessful SSL_OptionSetDefault enable TLS call."); |
+ } |
} else { |
ThrowException("Called TlsFilter::InitializeLibrary more than once"); |
} |
} |
+char* PasswordCallback(PK11SlotInfo* slot, PRBool retry, void* arg) { |
+ if (!retry) { |
+ return PL_strdup(static_cast<char*>(arg)); // Freed by NSS internals. |
+ } |
+ return NULL; |
+} |
-void TlsFilter::Connect(const char* host, int port) { |
+void TlsFilter::Connect(const char* host_name, |
+ int port, |
+ bool is_server, |
+ const char* certificate_name) { |
+ is_server_ = is_server; |
if (in_handshake_) { |
ThrowException("Connect called while already in handshake state."); |
} |
- PRFileDesc* my_socket = memio_; |
- my_socket = SSL_ImportFD(NULL, my_socket); |
- if (my_socket == NULL) { |
+ filter_ = SSL_ImportFD(NULL, filter_); |
+ if (filter_ == NULL) { |
ThrowPRException("Unsuccessful SSL_ImportFD call"); |
} |
- if (SSL_SetURL(my_socket, host) == -1) { |
- ThrowPRException("Unsuccessful SetURL call"); |
+ SECStatus status; |
+ if (is_server) { |
+ PK11_SetPasswordFunc(PasswordCallback); |
+ CERTCertDBHandle* certificate_database = CERT_GetDefaultCertDB(); |
+ if (certificate_database == NULL) { |
+ ThrowPRException("Certificate database cannot be loaded"); |
+ } |
+ CERTCertificate* certificate = CERT_FindCertByNameString( |
+ certificate_database, |
+ const_cast<char*>(certificate_name)); |
+ if (certificate == NULL) { |
+ ThrowPRException("Cannot find server certificate by name"); |
+ } |
+ SECKEYPrivateKey* key = PK11_FindKeyByAnyCert( |
+ certificate, |
+ static_cast<void*>(const_cast<char*>(password_))); |
+ if (key == NULL) { |
+ if (PR_GetError() == -8177) { |
+ ThrowPRException("Certificate database password incorrect"); |
+ } else { |
+ ThrowPRException("Unsuccessful PK11_FindKeyByAnyCert call." |
+ " Cannot find private key for certificate"); |
+ } |
+ } |
+ // kt_rsa (key type RSA) is an enum constant from the NSS libraries. |
+ // TODO(whesse): Allow different key types. |
+ status = SSL_ConfigSecureServer(filter_, certificate, key, kt_rsa); |
+ if (status != SECSuccess) { |
+ ThrowPRException("Unsuccessful SSL_ConfigSecureServer call"); |
+ } |
+ } else { // Client. |
+ if (SSL_SetURL(filter_, host_name) == -1) { |
+ ThrowPRException("Unsuccessful SetURL call"); |
+ } |
} |
- SECStatus status = SSL_ResetHandshake(my_socket, PR_FALSE); |
+ PRBool as_server = is_server ? PR_TRUE : PR_FALSE; // Convert bool to PRBool. |
+ status = SSL_ResetHandshake(filter_, as_server); |
if (status != SECSuccess) { |
ThrowPRException("Unsuccessful SSL_ResetHandshake call"); |
} |
@@ -237,7 +334,7 @@ void TlsFilter::Connect(const char* host, int port) { |
PRNetAddr host_address; |
char host_entry_buffer[PR_NETDB_BUF_SIZE]; |
PRHostEnt host_entry; |
- PRStatus rv = PR_GetHostByName(host, host_entry_buffer, |
+ PRStatus rv = PR_GetHostByName(host_name, host_entry_buffer, |
PR_NETDB_BUF_SIZE, &host_entry); |
if (rv != PR_SUCCESS) { |
ThrowPRException("Unsuccessful PR_GetHostByName call"); |
@@ -247,14 +344,12 @@ void TlsFilter::Connect(const char* host, int port) { |
if (index == -1 || index == 0) { |
ThrowPRException("Unsuccessful PR_EnumerateHostEnt call"); |
} |
- |
- memio_SetPeerName(my_socket, &host_address); |
- memio_ = my_socket; |
+ memio_SetPeerName(filter_, &host_address); |
} |
void TlsFilter::Handshake() { |
- SECStatus status = SSL_ForceHandshake(memio_); |
+ SECStatus status = SSL_ForceHandshake(filter_); |
if (status == SECSuccess) { |
if (in_handshake_) { |
ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL)); |
@@ -267,7 +362,11 @@ void TlsFilter::Handshake() { |
in_handshake_ = true; |
} |
} else { |
- ThrowPRException("Unexpected handshake error"); |
+ if (is_server_) { |
+ ThrowPRException("Unexpected handshake error in server"); |
+ } else { |
+ ThrowPRException("Unexpected handshake error in client"); |
+ } |
} |
} |
} |
@@ -305,7 +404,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
switch (buffer_index) { |
case kReadPlaintext: { |
int bytes_free = buffer_size_ - start - length; |
- bytes_processed = PR_Read(memio_, |
+ bytes_processed = PR_Read(filter_, |
buffer + start + length, |
bytes_free); |
if (bytes_processed < 0) { |
@@ -326,7 +425,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
unsigned int len1; |
unsigned int len2; |
int bytes_free = buffer_size_ - start - length; |
- memio_Private* secret = memio_GetSecret(memio_); |
+ memio_Private* secret = memio_GetSecret(filter_); |
memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2); |
int bytes_to_send = |
dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free)); |
@@ -350,11 +449,11 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
case kReadEncrypted: { |
if (length > 0) { |
bytes_processed = length; |
- memio_Private* secret = memio_GetSecret(memio_); |
- uint8_t* memio_buf; |
- int free_bytes = memio_GetReadParams(secret, &memio_buf); |
+ memio_Private* secret = memio_GetSecret(filter_); |
+ uint8_t* filter_buf; |
+ int free_bytes = memio_GetReadParams(secret, &filter_buf); |
if (free_bytes < bytes_processed) bytes_processed = free_bytes; |
- memmove(memio_buf, |
+ memmove(filter_buf, |
buffer + start, |
bytes_processed); |
memio_PutReadResult(secret, bytes_processed); |
@@ -364,7 +463,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) { |
case kWritePlaintext: { |
if (length > 0) { |
- bytes_processed = PR_Write(memio_, |
+ bytes_processed = PR_Write(filter_, |
buffer + start, |
length); |
} |