| OLD | NEW |
| 1 // Copyright (c) 2012, the Dart project authors. Please see the AUTHORS file | 1 // Copyright (c) 2012, the Dart project authors. Please see the AUTHORS file |
| 2 // for details. All rights reserved. Use of this source code is governed by a | 2 // for details. All rights reserved. Use of this source code is governed by a |
| 3 // BSD-style license that can be found in the LICENSE file. | 3 // BSD-style license that can be found in the LICENSE file. |
| 4 | 4 |
| 5 #include "bin/tls_socket.h" | 5 #include "bin/tls_socket.h" |
| 6 | 6 |
| 7 #include <errno.h> | 7 #include <errno.h> |
| 8 #include <fcntl.h> | 8 #include <fcntl.h> |
| 9 #include <sys/stat.h> | 9 #include <sys/stat.h> |
| 10 #include <stdio.h> | 10 #include <stdio.h> |
| 11 #include <string.h> | 11 #include <string.h> |
| 12 | 12 |
| 13 #include <nss.h> | 13 #include <nss.h> |
| 14 #include <pk11pub.h> |
| 14 #include <prerror.h> | 15 #include <prerror.h> |
| 15 #include <prinit.h> | 16 #include <prinit.h> |
| 16 #include <prnetdb.h> | 17 #include <prnetdb.h> |
| 17 #include <ssl.h> | 18 #include <ssl.h> |
| 19 #include <sslproto.h> |
| 18 | 20 |
| 19 #include "bin/builtin.h" | 21 #include "bin/builtin.h" |
| 20 #include "bin/dartutils.h" | 22 #include "bin/dartutils.h" |
| 21 #include "bin/net/nss_memio.h" | 23 #include "bin/net/nss_memio.h" |
| 22 #include "bin/thread.h" | 24 #include "bin/thread.h" |
| 23 #include "bin/utils.h" | 25 #include "bin/utils.h" |
| 24 #include "platform/utils.h" | 26 #include "platform/utils.h" |
| 25 | 27 |
| 26 #include "include/dart_api.h" | 28 #include "include/dart_api.h" |
| 27 | 29 |
| 28 bool TlsFilter::library_initialized_ = false; | 30 bool TlsFilter::library_initialized_ = false; |
| 29 dart::Mutex TlsFilter::mutex_; // To protect library initialization. | 31 dart::Mutex TlsFilter::mutex_; // To protect library initialization. |
| 32 // The password is needed when creating secure server sockets. It can |
| 33 // be null if only secure client sockets are used. |
| 34 const char* TlsFilter::password_ = NULL; |
| 35 |
| 30 static const int kTlsFilterNativeFieldIndex = 0; | 36 static const int kTlsFilterNativeFieldIndex = 0; |
| 31 | 37 |
| 32 static TlsFilter* GetTlsFilter(Dart_NativeArguments args) { | 38 static TlsFilter* GetTlsFilter(Dart_NativeArguments args) { |
| 33 TlsFilter* filter; | 39 TlsFilter* filter; |
| 34 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); | 40 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); |
| 35 ASSERT(Dart_IsInstance(dart_this)); | 41 ASSERT(Dart_IsInstance(dart_this)); |
| 36 ThrowIfError(Dart_GetNativeInstanceField( | 42 ThrowIfError(Dart_GetNativeInstanceField( |
| 37 dart_this, | 43 dart_this, |
| 38 kTlsFilterNativeFieldIndex, | 44 kTlsFilterNativeFieldIndex, |
| 39 reinterpret_cast<intptr_t*>(&filter))); | 45 reinterpret_cast<intptr_t*>(&filter))); |
| (...skipping 16 matching lines...) Expand all Loading... |
| 56 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); | 62 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); |
| 57 TlsFilter* filter = new TlsFilter; | 63 TlsFilter* filter = new TlsFilter; |
| 58 SetTlsFilter(args, filter); | 64 SetTlsFilter(args, filter); |
| 59 filter->Init(dart_this); | 65 filter->Init(dart_this); |
| 60 Dart_ExitScope(); | 66 Dart_ExitScope(); |
| 61 } | 67 } |
| 62 | 68 |
| 63 | 69 |
| 64 void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) { | 70 void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) { |
| 65 Dart_EnterScope(); | 71 Dart_EnterScope(); |
| 66 Dart_Handle host_name = ThrowIfError(Dart_GetNativeArgument(args, 1)); | 72 Dart_Handle host_name_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
| 67 Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2)); | 73 Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2)); |
| 74 Dart_Handle is_server_object = ThrowIfError(Dart_GetNativeArgument(args, 3)); |
| 75 Dart_Handle certificate_name_object = |
| 76 ThrowIfError(Dart_GetNativeArgument(args, 4)); |
| 68 | 77 |
| 69 const char* host_name_string = NULL; | 78 const char* host_name = NULL; |
| 70 // TODO(whesse): Is truncating a Dart string containing \0 what we want? | 79 // TODO(whesse): Is truncating a Dart string containing \0 what we want? |
| 71 ThrowIfError(Dart_StringToCString(host_name, &host_name_string)); | 80 ThrowIfError(Dart_StringToCString(host_name_object, &host_name)); |
| 72 | 81 |
| 73 int64_t port; | 82 int64_t port; |
| 74 if (!DartUtils::GetInt64Value(port_object, &port) || | 83 if (!DartUtils::GetInt64Value(port_object, &port) || |
| 75 port < 0 || port > 65535) { | 84 port < 0 || port > 65535) { |
| 76 Dart_ThrowException(DartUtils::NewDartArgumentError( | 85 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 77 "Illegal port parameter in TlsSocket")); | 86 "Illegal port parameter in _TlsFilter.connect")); |
| 78 } | 87 } |
| 79 | 88 |
| 80 GetTlsFilter(args)->Connect(host_name_string, static_cast<int>(port)); | 89 if (!Dart_IsBoolean(is_server_object)) { |
| 90 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 91 "Illegal is_server parameter in _TlsFilter.connect")); |
| 92 } |
| 93 bool is_server = DartUtils::GetBooleanValue(is_server_object); |
| 94 |
| 95 const char* certificate_name = NULL; |
| 96 // If this is a server connection, get the certificate to connect with. |
| 97 // TODO(whesse): Use this parameter for a client certificate as well. |
| 98 if (is_server) { |
| 99 if (!Dart_IsString(certificate_name_object)) { |
| 100 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 101 "Non-String certificate parameter in _TlsFilter.connect")); |
| 102 } |
| 103 ThrowIfError(Dart_StringToCString(certificate_name_object, |
| 104 &certificate_name)); |
| 105 } |
| 106 |
| 107 GetTlsFilter(args)->Connect(host_name, |
| 108 static_cast<int>(port), |
| 109 is_server, |
| 110 certificate_name); |
| 81 Dart_ExitScope(); | 111 Dart_ExitScope(); |
| 82 } | 112 } |
| 83 | 113 |
| 84 | 114 |
| 85 void FUNCTION_NAME(TlsSocket_Destroy)(Dart_NativeArguments args) { | 115 void FUNCTION_NAME(TlsSocket_Destroy)(Dart_NativeArguments args) { |
| 86 Dart_EnterScope(); | 116 Dart_EnterScope(); |
| 87 TlsFilter* filter = GetTlsFilter(args); | 117 TlsFilter* filter = GetTlsFilter(args); |
| 88 SetTlsFilter(args, NULL); | 118 SetTlsFilter(args, NULL); |
| 89 filter->Destroy(); | 119 filter->Destroy(); |
| 90 delete filter; | 120 delete filter; |
| (...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 125 intptr_t bytes_read = | 155 intptr_t bytes_read = |
| 126 GetTlsFilter(args)->ProcessBuffer(static_cast<int>(buffer_id)); | 156 GetTlsFilter(args)->ProcessBuffer(static_cast<int>(buffer_id)); |
| 127 Dart_SetReturnValue(args, Dart_NewInteger(bytes_read)); | 157 Dart_SetReturnValue(args, Dart_NewInteger(bytes_read)); |
| 128 Dart_ExitScope(); | 158 Dart_ExitScope(); |
| 129 } | 159 } |
| 130 | 160 |
| 131 | 161 |
| 132 void FUNCTION_NAME(TlsSocket_SetCertificateDatabase) | 162 void FUNCTION_NAME(TlsSocket_SetCertificateDatabase) |
| 133 (Dart_NativeArguments args) { | 163 (Dart_NativeArguments args) { |
| 134 Dart_EnterScope(); | 164 Dart_EnterScope(); |
| 135 Dart_Handle dart_pkcert_dir = ThrowIfError(Dart_GetNativeArgument(args, 0)); | 165 Dart_Handle certificate_database_object = |
| 166 ThrowIfError(Dart_GetNativeArgument(args, 0)); |
| 136 // Check that the type is string, and get the UTF-8 C string value from it. | 167 // Check that the type is string, and get the UTF-8 C string value from it. |
| 137 if (Dart_IsString(dart_pkcert_dir)) { | 168 const char* certificate_database = NULL; |
| 138 const char* pkcert_dir = NULL; | 169 if (Dart_IsString(certificate_database_object)) { |
| 139 ThrowIfError(Dart_StringToCString(dart_pkcert_dir, &pkcert_dir)); | 170 ThrowIfError(Dart_StringToCString(certificate_database_object, |
| 140 TlsFilter::InitializeLibrary(pkcert_dir); | 171 &certificate_database)); |
| 141 } else { | 172 } else { |
| 142 Dart_ThrowException(DartUtils::NewDartArgumentError( | 173 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 143 "Non-String argument to SetCertificateDatabase")); | 174 "Non-String certificate directory argument to SetCertificateDatabase")); |
| 144 } | 175 } |
| 176 |
| 177 Dart_Handle password_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
| 178 // Check that the type is string or null, |
| 179 // and get the UTF-8 C string value from it. |
| 180 const char* password = NULL; |
| 181 if (Dart_IsString(password_object)) { |
| 182 ThrowIfError(Dart_StringToCString(password_object, &password)); |
| 183 } else if (Dart_IsNull(password_object)) { |
| 184 // Pass the empty string as the password. |
| 185 password = ""; |
| 186 } else { |
| 187 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 188 "Password argument to SetCertificateDatabase is not a String or null")); |
| 189 } |
| 190 |
| 191 TlsFilter::InitializeLibrary(certificate_database, password); |
| 145 Dart_ExitScope(); | 192 Dart_ExitScope(); |
| 146 } | 193 } |
| 147 | 194 |
| 148 | 195 |
| 149 void TlsFilter::Init(Dart_Handle dart_this) { | 196 void TlsFilter::Init(Dart_Handle dart_this) { |
| 150 string_start_ = ThrowIfError( | 197 string_start_ = ThrowIfError( |
| 151 Dart_NewPersistentHandle(DartUtils::NewString("start"))); | 198 Dart_NewPersistentHandle(DartUtils::NewString("start"))); |
| 152 string_length_ = ThrowIfError( | 199 string_length_ = ThrowIfError( |
| 153 Dart_NewPersistentHandle(DartUtils::NewString("length"))); | 200 Dart_NewPersistentHandle(DartUtils::NewString("length"))); |
| 154 | 201 |
| 155 InitializeBuffers(dart_this); | 202 InitializeBuffers(dart_this); |
| 156 memio_ = memio_CreateIOLayer(kMemioBufferSize); | 203 filter_ = memio_CreateIOLayer(kMemioBufferSize); |
| 157 } | 204 } |
| 158 | 205 |
| 159 | 206 |
| 160 void TlsFilter::InitializeBuffers(Dart_Handle dart_this) { | 207 void TlsFilter::InitializeBuffers(Dart_Handle dart_this) { |
| 161 // Create TlsFilter buffers as ExternalUint8Array objects. | 208 // Create TlsFilter buffers as ExternalUint8Array objects. |
| 162 Dart_Handle dart_buffers_object = ThrowIfError( | 209 Dart_Handle dart_buffers_object = ThrowIfError( |
| 163 Dart_GetField(dart_this, DartUtils::NewString("buffers"))); | 210 Dart_GetField(dart_this, DartUtils::NewString("buffers"))); |
| 164 Dart_Handle dart_buffer_object = | 211 Dart_Handle dart_buffer_object = |
| 165 Dart_ListGetAt(dart_buffers_object, kReadPlaintext); | 212 Dart_ListGetAt(dart_buffers_object, kReadPlaintext); |
| 166 Dart_Handle tls_external_buffer_class = | 213 Dart_Handle tls_external_buffer_class = |
| (...skipping 19 matching lines...) Expand all Loading... |
| 186 } | 233 } |
| 187 } | 234 } |
| 188 | 235 |
| 189 | 236 |
| 190 void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) { | 237 void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) { |
| 191 ASSERT(NULL == handshake_complete_); | 238 ASSERT(NULL == handshake_complete_); |
| 192 handshake_complete_ = ThrowIfError(Dart_NewPersistentHandle(complete)); | 239 handshake_complete_ = ThrowIfError(Dart_NewPersistentHandle(complete)); |
| 193 } | 240 } |
| 194 | 241 |
| 195 | 242 |
| 196 void TlsFilter::InitializeLibrary(const char* pkcert_database) { | 243 void TlsFilter::InitializeLibrary(const char* certificate_database, |
| 244 const char* password) { |
| 197 MutexLocker locker(&mutex_); | 245 MutexLocker locker(&mutex_); |
| 198 if (!library_initialized_) { | 246 if (!library_initialized_) { |
| 247 library_initialized_ = true; |
| 248 password_ = strdup(password); // This one copy persists until Dart exits. |
| 199 PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0); | 249 PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0); |
| 200 // TODO(whesse): Verify there are no UTF-8 issues here. | 250 // TODO(whesse): Verify there are no UTF-8 issues here. |
| 201 SECStatus status = NSS_Init(pkcert_database); | 251 SECStatus status = NSS_Init(certificate_database); |
| 202 if (status != SECSuccess) { | 252 if (status != SECSuccess) { |
| 203 ThrowPRException("Unsuccessful NSS_Init call."); | 253 ThrowPRException("Unsuccessful NSS_Init call."); |
| 204 } | 254 } |
| 205 | 255 |
| 206 status = NSS_SetDomesticPolicy(); | 256 status = NSS_SetDomesticPolicy(); |
| 207 if (status != SECSuccess) { | 257 if (status != SECSuccess) { |
| 208 ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call."); | 258 ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call."); |
| 209 } | 259 } |
| 260 // Enable TLS, as well as SSL3 and SSL2. |
| 261 status = SSL_OptionSetDefault(SSL_ENABLE_TLS, PR_TRUE); |
| 262 if (status != SECSuccess) { |
| 263 ThrowPRException("Unsuccessful SSL_OptionSetDefault enable TLS call."); |
| 264 } |
| 210 } else { | 265 } else { |
| 211 ThrowException("Called TlsFilter::InitializeLibrary more than once"); | 266 ThrowException("Called TlsFilter::InitializeLibrary more than once"); |
| 212 } | 267 } |
| 213 } | 268 } |
| 214 | 269 |
| 270 char* PasswordCallback(PK11SlotInfo* slot, PRBool retry, void* arg) { |
| 271 if (!retry) { |
| 272 return PL_strdup(static_cast<char*>(arg)); // Freed by NSS internals. |
| 273 } |
| 274 return NULL; |
| 275 } |
| 215 | 276 |
| 216 void TlsFilter::Connect(const char* host, int port) { | 277 void TlsFilter::Connect(const char* host_name, |
| 278 int port, |
| 279 bool is_server, |
| 280 const char* certificate_name) { |
| 281 is_server_ = is_server; |
| 217 if (in_handshake_) { | 282 if (in_handshake_) { |
| 218 ThrowException("Connect called while already in handshake state."); | 283 ThrowException("Connect called while already in handshake state."); |
| 219 } | 284 } |
| 220 PRFileDesc* my_socket = memio_; | |
| 221 | 285 |
| 222 my_socket = SSL_ImportFD(NULL, my_socket); | 286 filter_ = SSL_ImportFD(NULL, filter_); |
| 223 if (my_socket == NULL) { | 287 if (filter_ == NULL) { |
| 224 ThrowPRException("Unsuccessful SSL_ImportFD call"); | 288 ThrowPRException("Unsuccessful SSL_ImportFD call"); |
| 225 } | 289 } |
| 226 | 290 |
| 227 if (SSL_SetURL(my_socket, host) == -1) { | 291 SECStatus status; |
| 228 ThrowPRException("Unsuccessful SetURL call"); | 292 if (is_server) { |
| 293 PK11_SetPasswordFunc(PasswordCallback); |
| 294 CERTCertDBHandle* certificate_database = CERT_GetDefaultCertDB(); |
| 295 if (certificate_database == NULL) { |
| 296 ThrowPRException("Certificate database cannot be loaded"); |
| 297 } |
| 298 CERTCertificate* certificate = CERT_FindCertByNameString( |
| 299 certificate_database, |
| 300 const_cast<char*>(certificate_name)); |
| 301 if (certificate == NULL) { |
| 302 ThrowPRException("Cannot find server certificate by name"); |
| 303 } |
| 304 SECKEYPrivateKey* key = PK11_FindKeyByAnyCert( |
| 305 certificate, |
| 306 static_cast<void*>(const_cast<char*>(password_))); |
| 307 if (key == NULL) { |
| 308 if (PR_GetError() == -8177) { |
| 309 ThrowPRException("Certificate database password incorrect"); |
| 310 } else { |
| 311 ThrowPRException("Unsuccessful PK11_FindKeyByAnyCert call." |
| 312 " Cannot find private key for certificate"); |
| 313 } |
| 314 } |
| 315 // kt_rsa (key type RSA) is an enum constant from the NSS libraries. |
| 316 // TODO(whesse): Allow different key types. |
| 317 status = SSL_ConfigSecureServer(filter_, certificate, key, kt_rsa); |
| 318 if (status != SECSuccess) { |
| 319 ThrowPRException("Unsuccessful SSL_ConfigSecureServer call"); |
| 320 } |
| 321 } else { // Client. |
| 322 if (SSL_SetURL(filter_, host_name) == -1) { |
| 323 ThrowPRException("Unsuccessful SetURL call"); |
| 324 } |
| 229 } | 325 } |
| 230 | 326 |
| 231 SECStatus status = SSL_ResetHandshake(my_socket, PR_FALSE); | 327 PRBool as_server = is_server ? PR_TRUE : PR_FALSE; // Convert bool to PRBool. |
| 328 status = SSL_ResetHandshake(filter_, as_server); |
| 232 if (status != SECSuccess) { | 329 if (status != SECSuccess) { |
| 233 ThrowPRException("Unsuccessful SSL_ResetHandshake call"); | 330 ThrowPRException("Unsuccessful SSL_ResetHandshake call"); |
| 234 } | 331 } |
| 235 | 332 |
| 236 // SetPeerAddress | 333 // SetPeerAddress |
| 237 PRNetAddr host_address; | 334 PRNetAddr host_address; |
| 238 char host_entry_buffer[PR_NETDB_BUF_SIZE]; | 335 char host_entry_buffer[PR_NETDB_BUF_SIZE]; |
| 239 PRHostEnt host_entry; | 336 PRHostEnt host_entry; |
| 240 PRStatus rv = PR_GetHostByName(host, host_entry_buffer, | 337 PRStatus rv = PR_GetHostByName(host_name, host_entry_buffer, |
| 241 PR_NETDB_BUF_SIZE, &host_entry); | 338 PR_NETDB_BUF_SIZE, &host_entry); |
| 242 if (rv != PR_SUCCESS) { | 339 if (rv != PR_SUCCESS) { |
| 243 ThrowPRException("Unsuccessful PR_GetHostByName call"); | 340 ThrowPRException("Unsuccessful PR_GetHostByName call"); |
| 244 } | 341 } |
| 245 | 342 |
| 246 int index = PR_EnumerateHostEnt(0, &host_entry, port, &host_address); | 343 int index = PR_EnumerateHostEnt(0, &host_entry, port, &host_address); |
| 247 if (index == -1 || index == 0) { | 344 if (index == -1 || index == 0) { |
| 248 ThrowPRException("Unsuccessful PR_EnumerateHostEnt call"); | 345 ThrowPRException("Unsuccessful PR_EnumerateHostEnt call"); |
| 249 } | 346 } |
| 250 | 347 memio_SetPeerName(filter_, &host_address); |
| 251 memio_SetPeerName(my_socket, &host_address); | |
| 252 memio_ = my_socket; | |
| 253 } | 348 } |
| 254 | 349 |
| 255 | 350 |
| 256 void TlsFilter::Handshake() { | 351 void TlsFilter::Handshake() { |
| 257 SECStatus status = SSL_ForceHandshake(memio_); | 352 SECStatus status = SSL_ForceHandshake(filter_); |
| 258 if (status == SECSuccess) { | 353 if (status == SECSuccess) { |
| 259 if (in_handshake_) { | 354 if (in_handshake_) { |
| 260 ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL)); | 355 ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL)); |
| 261 in_handshake_ = false; | 356 in_handshake_ = false; |
| 262 } | 357 } |
| 263 } else { | 358 } else { |
| 264 PRErrorCode error = PR_GetError(); | 359 PRErrorCode error = PR_GetError(); |
| 265 if (error == PR_WOULD_BLOCK_ERROR) { | 360 if (error == PR_WOULD_BLOCK_ERROR) { |
| 266 if (!in_handshake_) { | 361 if (!in_handshake_) { |
| 267 in_handshake_ = true; | 362 in_handshake_ = true; |
| 268 } | 363 } |
| 269 } else { | 364 } else { |
| 270 ThrowPRException("Unexpected handshake error"); | 365 if (is_server_) { |
| 366 ThrowPRException("Unexpected handshake error in server"); |
| 367 } else { |
| 368 ThrowPRException("Unexpected handshake error in client"); |
| 369 } |
| 271 } | 370 } |
| 272 } | 371 } |
| 273 } | 372 } |
| 274 | 373 |
| 275 | 374 |
| 276 void TlsFilter::Destroy() { | 375 void TlsFilter::Destroy() { |
| 277 for (int i = 0; i < kNumBuffers; ++i) { | 376 for (int i = 0; i < kNumBuffers; ++i) { |
| 278 Dart_DeletePersistentHandle(dart_buffer_objects_[i]); | 377 Dart_DeletePersistentHandle(dart_buffer_objects_[i]); |
| 279 delete[] buffers_[i]; | 378 delete[] buffers_[i]; |
| 280 } | 379 } |
| (...skipping 17 matching lines...) Expand all Loading... |
| 298 ASSERT(unsafe_length >= 0); | 397 ASSERT(unsafe_length >= 0); |
| 299 ASSERT(unsafe_length <= buffer_size_); | 398 ASSERT(unsafe_length <= buffer_size_); |
| 300 intptr_t start = static_cast<intptr_t>(unsafe_start); | 399 intptr_t start = static_cast<intptr_t>(unsafe_start); |
| 301 intptr_t length = static_cast<intptr_t>(unsafe_length); | 400 intptr_t length = static_cast<intptr_t>(unsafe_length); |
| 302 uint8_t* buffer = buffers_[buffer_index]; | 401 uint8_t* buffer = buffers_[buffer_index]; |
| 303 | 402 |
| 304 int bytes_processed = 0; | 403 int bytes_processed = 0; |
| 305 switch (buffer_index) { | 404 switch (buffer_index) { |
| 306 case kReadPlaintext: { | 405 case kReadPlaintext: { |
| 307 int bytes_free = buffer_size_ - start - length; | 406 int bytes_free = buffer_size_ - start - length; |
| 308 bytes_processed = PR_Read(memio_, | 407 bytes_processed = PR_Read(filter_, |
| 309 buffer + start + length, | 408 buffer + start + length, |
| 310 bytes_free); | 409 bytes_free); |
| 311 if (bytes_processed < 0) { | 410 if (bytes_processed < 0) { |
| 312 ASSERT(bytes_processed == -1); | 411 ASSERT(bytes_processed == -1); |
| 313 // TODO(whesse): Handle unexpected errors here. | 412 // TODO(whesse): Handle unexpected errors here. |
| 314 PRErrorCode pr_error = PR_GetError(); | 413 PRErrorCode pr_error = PR_GetError(); |
| 315 if (PR_WOULD_BLOCK_ERROR != pr_error) { | 414 if (PR_WOULD_BLOCK_ERROR != pr_error) { |
| 316 ThrowPRException("Error reading plaintext from TlsFilter"); | 415 ThrowPRException("Error reading plaintext from TlsFilter"); |
| 317 } | 416 } |
| 318 bytes_processed = 0; | 417 bytes_processed = 0; |
| 319 } | 418 } |
| 320 break; | 419 break; |
| 321 } | 420 } |
| 322 | 421 |
| 323 case kWriteEncrypted: { | 422 case kWriteEncrypted: { |
| 324 const uint8_t* buf1; | 423 const uint8_t* buf1; |
| 325 const uint8_t* buf2; | 424 const uint8_t* buf2; |
| 326 unsigned int len1; | 425 unsigned int len1; |
| 327 unsigned int len2; | 426 unsigned int len2; |
| 328 int bytes_free = buffer_size_ - start - length; | 427 int bytes_free = buffer_size_ - start - length; |
| 329 memio_Private* secret = memio_GetSecret(memio_); | 428 memio_Private* secret = memio_GetSecret(filter_); |
| 330 memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2); | 429 memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2); |
| 331 int bytes_to_send = | 430 int bytes_to_send = |
| 332 dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free)); | 431 dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free)); |
| 333 if (bytes_to_send > 0) { | 432 if (bytes_to_send > 0) { |
| 334 memmove(buffer + start + length, buf1, bytes_to_send); | 433 memmove(buffer + start + length, buf1, bytes_to_send); |
| 335 bytes_processed = bytes_to_send; | 434 bytes_processed = bytes_to_send; |
| 336 } | 435 } |
| 337 bytes_to_send = dart::Utils::Minimum(len2, | 436 bytes_to_send = dart::Utils::Minimum(len2, |
| 338 static_cast<unsigned>(bytes_free - bytes_processed)); | 437 static_cast<unsigned>(bytes_free - bytes_processed)); |
| 339 if (bytes_to_send > 0) { | 438 if (bytes_to_send > 0) { |
| 340 memmove(buffer + start + length + bytes_processed, buf2, | 439 memmove(buffer + start + length + bytes_processed, buf2, |
| 341 bytes_to_send); | 440 bytes_to_send); |
| 342 bytes_processed += bytes_to_send; | 441 bytes_processed += bytes_to_send; |
| 343 } | 442 } |
| 344 if (bytes_processed > 0) { | 443 if (bytes_processed > 0) { |
| 345 memio_PutWriteResult(secret, bytes_processed); | 444 memio_PutWriteResult(secret, bytes_processed); |
| 346 } | 445 } |
| 347 break; | 446 break; |
| 348 } | 447 } |
| 349 | 448 |
| 350 case kReadEncrypted: { | 449 case kReadEncrypted: { |
| 351 if (length > 0) { | 450 if (length > 0) { |
| 352 bytes_processed = length; | 451 bytes_processed = length; |
| 353 memio_Private* secret = memio_GetSecret(memio_); | 452 memio_Private* secret = memio_GetSecret(filter_); |
| 354 uint8_t* memio_buf; | 453 uint8_t* filter_buf; |
| 355 int free_bytes = memio_GetReadParams(secret, &memio_buf); | 454 int free_bytes = memio_GetReadParams(secret, &filter_buf); |
| 356 if (free_bytes < bytes_processed) bytes_processed = free_bytes; | 455 if (free_bytes < bytes_processed) bytes_processed = free_bytes; |
| 357 memmove(memio_buf, | 456 memmove(filter_buf, |
| 358 buffer + start, | 457 buffer + start, |
| 359 bytes_processed); | 458 bytes_processed); |
| 360 memio_PutReadResult(secret, bytes_processed); | 459 memio_PutReadResult(secret, bytes_processed); |
| 361 } | 460 } |
| 362 break; | 461 break; |
| 363 } | 462 } |
| 364 | 463 |
| 365 case kWritePlaintext: { | 464 case kWritePlaintext: { |
| 366 if (length > 0) { | 465 if (length > 0) { |
| 367 bytes_processed = PR_Write(memio_, | 466 bytes_processed = PR_Write(filter_, |
| 368 buffer + start, | 467 buffer + start, |
| 369 length); | 468 length); |
| 370 } | 469 } |
| 371 | 470 |
| 372 if (bytes_processed < 0) { | 471 if (bytes_processed < 0) { |
| 373 ASSERT(bytes_processed == -1); | 472 ASSERT(bytes_processed == -1); |
| 374 // TODO(whesse): Handle unexpected errors here. | 473 // TODO(whesse): Handle unexpected errors here. |
| 375 PRErrorCode pr_error = PR_GetError(); | 474 PRErrorCode pr_error = PR_GetError(); |
| 376 if (PR_WOULD_BLOCK_ERROR != pr_error) { | 475 if (PR_WOULD_BLOCK_ERROR != pr_error) { |
| 377 ThrowPRException("Error reading plaintext from TlsFilter"); | 476 ThrowPRException("Error reading plaintext from TlsFilter"); |
| 378 } | 477 } |
| 379 bytes_processed = 0; | 478 bytes_processed = 0; |
| 380 } | 479 } |
| 381 break; | 480 break; |
| 382 } | 481 } |
| 383 } | 482 } |
| 384 return bytes_processed; | 483 return bytes_processed; |
| 385 } | 484 } |
| OLD | NEW |