| Index: runtime/bin/tls_socket.cc
|
| diff --git a/runtime/bin/tls_socket.cc b/runtime/bin/tls_socket.cc
|
| index c4bc0a533ab71a7fa5b65871e6555b51557802d5..53ee162c772f5cc55fc1eb292fa32996eb93ac41 100644
|
| --- a/runtime/bin/tls_socket.cc
|
| +++ b/runtime/bin/tls_socket.cc
|
| @@ -11,10 +11,12 @@
|
| #include <string.h>
|
|
|
| #include <nss.h>
|
| +#include <pk11pub.h>
|
| #include <prerror.h>
|
| #include <prinit.h>
|
| #include <prnetdb.h>
|
| #include <ssl.h>
|
| +#include <sslproto.h>
|
|
|
| #include "bin/builtin.h"
|
| #include "bin/dartutils.h"
|
| @@ -27,6 +29,10 @@
|
|
|
| bool TlsFilter::library_initialized_ = false;
|
| dart::Mutex TlsFilter::mutex_; // To protect library initialization.
|
| +// The password is needed when creating secure server sockets. It can
|
| +// be null if only secure client sockets are used.
|
| +const char* TlsFilter::password_ = NULL;
|
| +
|
| static const int kTlsFilterNativeFieldIndex = 0;
|
|
|
| static TlsFilter* GetTlsFilter(Dart_NativeArguments args) {
|
| @@ -63,21 +69,45 @@ void FUNCTION_NAME(TlsSocket_Init)(Dart_NativeArguments args) {
|
|
|
| void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) {
|
| Dart_EnterScope();
|
| - Dart_Handle host_name = ThrowIfError(Dart_GetNativeArgument(args, 1));
|
| + Dart_Handle host_name_object = ThrowIfError(Dart_GetNativeArgument(args, 1));
|
| Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2));
|
| + Dart_Handle is_server_object = ThrowIfError(Dart_GetNativeArgument(args, 3));
|
| + Dart_Handle certificate_name_object =
|
| + ThrowIfError(Dart_GetNativeArgument(args, 4));
|
|
|
| - const char* host_name_string = NULL;
|
| + const char* host_name = NULL;
|
| // TODO(whesse): Is truncating a Dart string containing \0 what we want?
|
| - ThrowIfError(Dart_StringToCString(host_name, &host_name_string));
|
| + ThrowIfError(Dart_StringToCString(host_name_object, &host_name));
|
|
|
| int64_t port;
|
| if (!DartUtils::GetInt64Value(port_object, &port) ||
|
| port < 0 || port > 65535) {
|
| Dart_ThrowException(DartUtils::NewDartArgumentError(
|
| - "Illegal port parameter in TlsSocket"));
|
| + "Illegal port parameter in _TlsFilter.connect"));
|
| + }
|
| +
|
| + if (!Dart_IsBoolean(is_server_object)) {
|
| + Dart_ThrowException(DartUtils::NewDartArgumentError(
|
| + "Illegal is_server parameter in _TlsFilter.connect"));
|
| + }
|
| + bool is_server = DartUtils::GetBooleanValue(is_server_object);
|
| +
|
| + const char* certificate_name = NULL;
|
| + // If this is a server connection, get the certificate to connect with.
|
| + // TODO(whesse): Use this parameter for a client certificate as well.
|
| + if (is_server) {
|
| + if (!Dart_IsString(certificate_name_object)) {
|
| + Dart_ThrowException(DartUtils::NewDartArgumentError(
|
| + "Non-String certificate parameter in _TlsFilter.connect"));
|
| + }
|
| + ThrowIfError(Dart_StringToCString(certificate_name_object,
|
| + &certificate_name));
|
| }
|
|
|
| - GetTlsFilter(args)->Connect(host_name_string, static_cast<int>(port));
|
| + GetTlsFilter(args)->Connect(host_name,
|
| + static_cast<int>(port),
|
| + is_server,
|
| + certificate_name);
|
| Dart_ExitScope();
|
| }
|
|
|
| @@ -132,16 +162,33 @@ void FUNCTION_NAME(TlsSocket_ProcessBuffer)(Dart_NativeArguments args) {
|
| void FUNCTION_NAME(TlsSocket_SetCertificateDatabase)
|
| (Dart_NativeArguments args) {
|
| Dart_EnterScope();
|
| - Dart_Handle dart_pkcert_dir = ThrowIfError(Dart_GetNativeArgument(args, 0));
|
| + Dart_Handle certificate_database_object =
|
| + ThrowIfError(Dart_GetNativeArgument(args, 0));
|
| // Check that the type is string, and get the UTF-8 C string value from it.
|
| - if (Dart_IsString(dart_pkcert_dir)) {
|
| - const char* pkcert_dir = NULL;
|
| - ThrowIfError(Dart_StringToCString(dart_pkcert_dir, &pkcert_dir));
|
| - TlsFilter::InitializeLibrary(pkcert_dir);
|
| + const char* certificate_database = NULL;
|
| + if (Dart_IsString(certificate_database_object)) {
|
| + ThrowIfError(Dart_StringToCString(certificate_database_object,
|
| + &certificate_database));
|
| } else {
|
| Dart_ThrowException(DartUtils::NewDartArgumentError(
|
| - "Non-String argument to SetCertificateDatabase"));
|
| + "Non-String certificate directory argument to SetCertificateDatabase"));
|
| }
|
| +
|
| + Dart_Handle password_object = ThrowIfError(Dart_GetNativeArgument(args, 1));
|
| + // Check that the type is string or null,
|
| + // and get the UTF-8 C string value from it.
|
| + const char* password = NULL;
|
| + if (Dart_IsString(password_object)) {
|
| + ThrowIfError(Dart_StringToCString(password_object, &password));
|
| + } else if (Dart_IsNull(password_object)) {
|
| + // Pass the empty string as the password.
|
| + password = "";
|
| + } else {
|
| + Dart_ThrowException(DartUtils::NewDartArgumentError(
|
| + "Password argument to SetCertificateDatabase is not a String or null"));
|
| + }
|
| +
|
| + TlsFilter::InitializeLibrary(certificate_database, password);
|
| Dart_ExitScope();
|
| }
|
|
|
| @@ -153,7 +200,7 @@ void TlsFilter::Init(Dart_Handle dart_this) {
|
| Dart_NewPersistentHandle(DartUtils::NewString("length")));
|
|
|
| InitializeBuffers(dart_this);
|
| - memio_ = memio_CreateIOLayer(kMemioBufferSize);
|
| + filter_ = memio_CreateIOLayer(kMemioBufferSize);
|
| }
|
|
|
|
|
| @@ -193,12 +240,15 @@ void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) {
|
| }
|
|
|
|
|
| -void TlsFilter::InitializeLibrary(const char* pkcert_database) {
|
| +void TlsFilter::InitializeLibrary(const char* certificate_database,
|
| + const char* password) {
|
| MutexLocker locker(&mutex_);
|
| if (!library_initialized_) {
|
| + library_initialized_ = true;
|
| + password_ = strdup(password); // This one copy persists until Dart exits.
|
| PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0);
|
| // TODO(whesse): Verify there are no UTF-8 issues here.
|
| - SECStatus status = NSS_Init(pkcert_database);
|
| + SECStatus status = NSS_Init(certificate_database);
|
| if (status != SECSuccess) {
|
| ThrowPRException("Unsuccessful NSS_Init call.");
|
| }
|
| @@ -207,28 +257,75 @@ void TlsFilter::InitializeLibrary(const char* pkcert_database) {
|
| if (status != SECSuccess) {
|
| ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call.");
|
| }
|
| + // Enable TLS, as well as SSL3 and SSL2.
|
| + status = SSL_OptionSetDefault(SSL_ENABLE_TLS, PR_TRUE);
|
| + if (status != SECSuccess) {
|
| + ThrowPRException("Unsuccessful SSL_OptionSetDefault enable TLS call.");
|
| + }
|
| } else {
|
| ThrowException("Called TlsFilter::InitializeLibrary more than once");
|
| }
|
| }
|
|
|
| +char* PasswordCallback(PK11SlotInfo* slot, PRBool retry, void* arg) {
|
| + if (!retry) {
|
| + return PL_strdup(static_cast<char*>(arg)); // Freed by NSS internals.
|
| + }
|
| + return NULL;
|
| +}
|
|
|
| -void TlsFilter::Connect(const char* host, int port) {
|
| +void TlsFilter::Connect(const char* host_name,
|
| + int port,
|
| + bool is_server,
|
| + const char* certificate_name) {
|
| + is_server_ = is_server;
|
| if (in_handshake_) {
|
| ThrowException("Connect called while already in handshake state.");
|
| }
|
| - PRFileDesc* my_socket = memio_;
|
|
|
| - my_socket = SSL_ImportFD(NULL, my_socket);
|
| - if (my_socket == NULL) {
|
| + filter_ = SSL_ImportFD(NULL, filter_);
|
| + if (filter_ == NULL) {
|
| ThrowPRException("Unsuccessful SSL_ImportFD call");
|
| }
|
|
|
| - if (SSL_SetURL(my_socket, host) == -1) {
|
| - ThrowPRException("Unsuccessful SetURL call");
|
| + SECStatus status;
|
| + if (is_server) {
|
| + PK11_SetPasswordFunc(PasswordCallback);
|
| + CERTCertDBHandle* certificate_database = CERT_GetDefaultCertDB();
|
| + if (certificate_database == NULL) {
|
| + ThrowPRException("Certificate database cannot be loaded");
|
| + }
|
| + CERTCertificate* certificate = CERT_FindCertByNameString(
|
| + certificate_database,
|
| + const_cast<char*>(certificate_name));
|
| + if (certificate == NULL) {
|
| + ThrowPRException("Cannot find server certificate by name");
|
| + }
|
| + SECKEYPrivateKey* key = PK11_FindKeyByAnyCert(
|
| + certificate,
|
| + static_cast<void*>(const_cast<char*>(password_)));
|
| + if (key == NULL) {
|
| + if (PR_GetError() == -8177) {
|
| + ThrowPRException("Certificate database password incorrect");
|
| + } else {
|
| + ThrowPRException("Unsuccessful PK11_FindKeyByAnyCert call."
|
| + " Cannot find private key for certificate");
|
| + }
|
| + }
|
| + // kt_rsa (key type RSA) is an enum constant from the NSS libraries.
|
| + // TODO(whesse): Allow different key types.
|
| + status = SSL_ConfigSecureServer(filter_, certificate, key, kt_rsa);
|
| + if (status != SECSuccess) {
|
| + ThrowPRException("Unsuccessful SSL_ConfigSecureServer call");
|
| + }
|
| + } else { // Client.
|
| + if (SSL_SetURL(filter_, host_name) == -1) {
|
| + ThrowPRException("Unsuccessful SetURL call");
|
| + }
|
| }
|
|
|
| - SECStatus status = SSL_ResetHandshake(my_socket, PR_FALSE);
|
| + PRBool as_server = is_server ? PR_TRUE : PR_FALSE; // Convert bool to PRBool.
|
| + status = SSL_ResetHandshake(filter_, as_server);
|
| if (status != SECSuccess) {
|
| ThrowPRException("Unsuccessful SSL_ResetHandshake call");
|
| }
|
| @@ -237,7 +334,7 @@ void TlsFilter::Connect(const char* host, int port) {
|
| PRNetAddr host_address;
|
| char host_entry_buffer[PR_NETDB_BUF_SIZE];
|
| PRHostEnt host_entry;
|
| - PRStatus rv = PR_GetHostByName(host, host_entry_buffer,
|
| + PRStatus rv = PR_GetHostByName(host_name, host_entry_buffer,
|
| PR_NETDB_BUF_SIZE, &host_entry);
|
| if (rv != PR_SUCCESS) {
|
| ThrowPRException("Unsuccessful PR_GetHostByName call");
|
| @@ -247,14 +344,12 @@ void TlsFilter::Connect(const char* host, int port) {
|
| if (index == -1 || index == 0) {
|
| ThrowPRException("Unsuccessful PR_EnumerateHostEnt call");
|
| }
|
| -
|
| - memio_SetPeerName(my_socket, &host_address);
|
| - memio_ = my_socket;
|
| + memio_SetPeerName(filter_, &host_address);
|
| }
|
|
|
|
|
| void TlsFilter::Handshake() {
|
| - SECStatus status = SSL_ForceHandshake(memio_);
|
| + SECStatus status = SSL_ForceHandshake(filter_);
|
| if (status == SECSuccess) {
|
| if (in_handshake_) {
|
| ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL));
|
| @@ -267,7 +362,11 @@ void TlsFilter::Handshake() {
|
| in_handshake_ = true;
|
| }
|
| } else {
|
| - ThrowPRException("Unexpected handshake error");
|
| + if (is_server_) {
|
| + ThrowPRException("Unexpected handshake error in server");
|
| + } else {
|
| + ThrowPRException("Unexpected handshake error in client");
|
| + }
|
| }
|
| }
|
| }
|
| @@ -305,7 +404,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) {
|
| switch (buffer_index) {
|
| case kReadPlaintext: {
|
| int bytes_free = buffer_size_ - start - length;
|
| - bytes_processed = PR_Read(memio_,
|
| + bytes_processed = PR_Read(filter_,
|
| buffer + start + length,
|
| bytes_free);
|
| if (bytes_processed < 0) {
|
| @@ -326,7 +425,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) {
|
| unsigned int len1;
|
| unsigned int len2;
|
| int bytes_free = buffer_size_ - start - length;
|
| - memio_Private* secret = memio_GetSecret(memio_);
|
| + memio_Private* secret = memio_GetSecret(filter_);
|
| memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2);
|
| int bytes_to_send =
|
| dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free));
|
| @@ -350,11 +449,11 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) {
|
| case kReadEncrypted: {
|
| if (length > 0) {
|
| bytes_processed = length;
|
| - memio_Private* secret = memio_GetSecret(memio_);
|
| - uint8_t* memio_buf;
|
| - int free_bytes = memio_GetReadParams(secret, &memio_buf);
|
| + memio_Private* secret = memio_GetSecret(filter_);
|
| + uint8_t* filter_buf;
|
| + int free_bytes = memio_GetReadParams(secret, &filter_buf);
|
| if (free_bytes < bytes_processed) bytes_processed = free_bytes;
|
| - memmove(memio_buf,
|
| + memmove(filter_buf,
|
| buffer + start,
|
| bytes_processed);
|
| memio_PutReadResult(secret, bytes_processed);
|
| @@ -364,7 +463,7 @@ intptr_t TlsFilter::ProcessBuffer(int buffer_index) {
|
|
|
| case kWritePlaintext: {
|
| if (length > 0) {
|
| - bytes_processed = PR_Write(memio_,
|
| + bytes_processed = PR_Write(filter_,
|
| buffer + start,
|
| length);
|
| }
|
|
|