OLD | NEW |
1 // Copyright (c) 2012, the Dart project authors. Please see the AUTHORS file | 1 // Copyright (c) 2012, the Dart project authors. Please see the AUTHORS file |
2 // for details. All rights reserved. Use of this source code is governed by a | 2 // for details. All rights reserved. Use of this source code is governed by a |
3 // BSD-style license that can be found in the LICENSE file. | 3 // BSD-style license that can be found in the LICENSE file. |
4 | 4 |
5 #include "bin/tls_socket.h" | 5 #include "bin/tls_socket.h" |
6 | 6 |
7 #include <errno.h> | 7 #include <errno.h> |
8 #include <fcntl.h> | 8 #include <fcntl.h> |
9 #include <sys/stat.h> | 9 #include <sys/stat.h> |
10 #include <stdio.h> | 10 #include <stdio.h> |
11 #include <string.h> | 11 #include <string.h> |
12 | 12 |
13 #include <nss.h> | 13 #include <nss.h> |
| 14 #include <pk11pub.h> |
14 #include <prerror.h> | 15 #include <prerror.h> |
15 #include <prinit.h> | 16 #include <prinit.h> |
16 #include <prnetdb.h> | 17 #include <prnetdb.h> |
17 #include <ssl.h> | 18 #include <ssl.h> |
| 19 #include <sslproto.h> |
18 | 20 |
19 #include "bin/builtin.h" | 21 #include "bin/builtin.h" |
20 #include "bin/dartutils.h" | 22 #include "bin/dartutils.h" |
21 #include "bin/net/nss_memio.h" | 23 #include "bin/net/nss_memio.h" |
22 #include "bin/thread.h" | 24 #include "bin/thread.h" |
23 #include "bin/utils.h" | 25 #include "bin/utils.h" |
24 #include "platform/utils.h" | 26 #include "platform/utils.h" |
25 | 27 |
26 #include "include/dart_api.h" | 28 #include "include/dart_api.h" |
27 | 29 |
28 bool TlsFilter::library_initialized_ = false; | 30 bool TlsFilter::library_initialized_ = false; |
29 dart::Mutex TlsFilter::mutex_; // To protect library initialization. | 31 dart::Mutex TlsFilter::mutex_; // To protect library initialization. |
| 32 // The password is needed when creating secure server sockets. It can |
| 33 // be null if only secure client sockets are used. |
| 34 const char* TlsFilter::password_ = NULL; |
| 35 |
30 static const int kTlsFilterNativeFieldIndex = 0; | 36 static const int kTlsFilterNativeFieldIndex = 0; |
31 | 37 |
32 static TlsFilter* GetTlsFilter(Dart_NativeArguments args) { | 38 static TlsFilter* GetTlsFilter(Dart_NativeArguments args) { |
33 TlsFilter* filter; | 39 TlsFilter* filter; |
34 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); | 40 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); |
35 ASSERT(Dart_IsInstance(dart_this)); | 41 ASSERT(Dart_IsInstance(dart_this)); |
36 ThrowIfError(Dart_GetNativeInstanceField( | 42 ThrowIfError(Dart_GetNativeInstanceField( |
37 dart_this, | 43 dart_this, |
38 kTlsFilterNativeFieldIndex, | 44 kTlsFilterNativeFieldIndex, |
39 reinterpret_cast<intptr_t*>(&filter))); | 45 reinterpret_cast<intptr_t*>(&filter))); |
(...skipping 16 matching lines...) Expand all Loading... |
56 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); | 62 Dart_Handle dart_this = ThrowIfError(Dart_GetNativeArgument(args, 0)); |
57 TlsFilter* filter = new TlsFilter; | 63 TlsFilter* filter = new TlsFilter; |
58 SetTlsFilter(args, filter); | 64 SetTlsFilter(args, filter); |
59 filter->Init(dart_this); | 65 filter->Init(dart_this); |
60 Dart_ExitScope(); | 66 Dart_ExitScope(); |
61 } | 67 } |
62 | 68 |
63 | 69 |
64 void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) { | 70 void FUNCTION_NAME(TlsSocket_Connect)(Dart_NativeArguments args) { |
65 Dart_EnterScope(); | 71 Dart_EnterScope(); |
66 Dart_Handle host_name = ThrowIfError(Dart_GetNativeArgument(args, 1)); | 72 Dart_Handle host_name_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
67 Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2)); | 73 Dart_Handle port_object = ThrowIfError(Dart_GetNativeArgument(args, 2)); |
| 74 Dart_Handle is_server_object = ThrowIfError(Dart_GetNativeArgument(args, 3)); |
| 75 Dart_Handle certificate_name_object = |
| 76 ThrowIfError(Dart_GetNativeArgument(args, 4)); |
68 | 77 |
69 const char* host_name_string = NULL; | 78 const char* host_name = NULL; |
70 // TODO(whesse): Is truncating a Dart string containing \0 what we want? | 79 // TODO(whesse): Is truncating a Dart string containing \0 what we want? |
71 ThrowIfError(Dart_StringToCString(host_name, &host_name_string)); | 80 ThrowIfError(Dart_StringToCString(host_name_object, &host_name)); |
72 | 81 |
73 int64_t port; | 82 int64_t port; |
74 if (!DartUtils::GetInt64Value(port_object, &port) || | 83 if (!DartUtils::GetInt64Value(port_object, &port) || |
75 port < 0 || port > 65535) { | 84 port < 0 || port > 65535) { |
76 Dart_ThrowException(DartUtils::NewDartArgumentError( | 85 Dart_ThrowException(DartUtils::NewDartArgumentError( |
77 "Illegal port parameter in TlsSocket")); | 86 "Illegal port parameter in _TlsFilter.connect")); |
78 } | 87 } |
79 | 88 |
80 GetTlsFilter(args)->Connect(host_name_string, static_cast<int>(port)); | 89 if (!Dart_IsBoolean(is_server_object)) { |
| 90 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 91 "Illegal is_server parameter in _TlsFilter.connect")); |
| 92 } |
| 93 bool is_server = DartUtils::GetBooleanValue(is_server_object); |
| 94 |
| 95 const char* certificate_name = NULL; |
| 96 // If this is a server connection, get the certificate to connect with. |
| 97 // TODO(whesse): Use this parameter for a client certificate as well. |
| 98 if (is_server) { |
| 99 if (!Dart_IsString(certificate_name_object)) { |
| 100 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 101 "Non-String certificate parameter in _TlsFilter.connect")); |
| 102 } |
| 103 ThrowIfError(Dart_StringToCString(certificate_name_object, |
| 104 &certificate_name)); |
| 105 } |
| 106 |
| 107 GetTlsFilter(args)->Connect(host_name, |
| 108 static_cast<int>(port), |
| 109 is_server, |
| 110 certificate_name); |
81 Dart_ExitScope(); | 111 Dart_ExitScope(); |
82 } | 112 } |
83 | 113 |
84 | 114 |
85 void FUNCTION_NAME(TlsSocket_Destroy)(Dart_NativeArguments args) { | 115 void FUNCTION_NAME(TlsSocket_Destroy)(Dart_NativeArguments args) { |
86 Dart_EnterScope(); | 116 Dart_EnterScope(); |
87 TlsFilter* filter = GetTlsFilter(args); | 117 TlsFilter* filter = GetTlsFilter(args); |
88 SetTlsFilter(args, NULL); | 118 SetTlsFilter(args, NULL); |
89 filter->Destroy(); | 119 filter->Destroy(); |
90 delete filter; | 120 delete filter; |
(...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
125 intptr_t bytes_read = | 155 intptr_t bytes_read = |
126 GetTlsFilter(args)->ProcessBuffer(static_cast<int>(buffer_id)); | 156 GetTlsFilter(args)->ProcessBuffer(static_cast<int>(buffer_id)); |
127 Dart_SetReturnValue(args, Dart_NewInteger(bytes_read)); | 157 Dart_SetReturnValue(args, Dart_NewInteger(bytes_read)); |
128 Dart_ExitScope(); | 158 Dart_ExitScope(); |
129 } | 159 } |
130 | 160 |
131 | 161 |
132 void FUNCTION_NAME(TlsSocket_SetCertificateDatabase) | 162 void FUNCTION_NAME(TlsSocket_SetCertificateDatabase) |
133 (Dart_NativeArguments args) { | 163 (Dart_NativeArguments args) { |
134 Dart_EnterScope(); | 164 Dart_EnterScope(); |
135 Dart_Handle dart_pkcert_dir = ThrowIfError(Dart_GetNativeArgument(args, 0)); | 165 Dart_Handle certificate_database_object = |
| 166 ThrowIfError(Dart_GetNativeArgument(args, 0)); |
136 // Check that the type is string, and get the UTF-8 C string value from it. | 167 // Check that the type is string, and get the UTF-8 C string value from it. |
137 if (Dart_IsString(dart_pkcert_dir)) { | 168 const char* certificate_database = NULL; |
138 const char* pkcert_dir = NULL; | 169 if (Dart_IsString(certificate_database_object)) { |
139 ThrowIfError(Dart_StringToCString(dart_pkcert_dir, &pkcert_dir)); | 170 ThrowIfError(Dart_StringToCString(certificate_database_object, |
140 TlsFilter::InitializeLibrary(pkcert_dir); | 171 &certificate_database)); |
141 } else { | 172 } else { |
142 Dart_ThrowException(DartUtils::NewDartArgumentError( | 173 Dart_ThrowException(DartUtils::NewDartArgumentError( |
143 "Non-String argument to SetCertificateDatabase")); | 174 "Non-String certificate directory argument to SetCertificateDatabase")); |
144 } | 175 } |
| 176 |
| 177 Dart_Handle password_object = ThrowIfError(Dart_GetNativeArgument(args, 1)); |
| 178 // Check that the type is string or null, |
| 179 // and get the UTF-8 C string value from it. |
| 180 const char* password = NULL; |
| 181 if (Dart_IsString(password_object)) { |
| 182 ThrowIfError(Dart_StringToCString(password_object, &password)); |
| 183 } else if (Dart_IsNull(password_object)) { |
| 184 // Pass the empty string as the password. |
| 185 password = ""; |
| 186 } else { |
| 187 Dart_ThrowException(DartUtils::NewDartArgumentError( |
| 188 "Password argument to SetCertificateDatabase is not a String or null")); |
| 189 } |
| 190 |
| 191 TlsFilter::InitializeLibrary(certificate_database, password); |
145 Dart_ExitScope(); | 192 Dart_ExitScope(); |
146 } | 193 } |
147 | 194 |
148 | 195 |
149 void TlsFilter::Init(Dart_Handle dart_this) { | 196 void TlsFilter::Init(Dart_Handle dart_this) { |
150 string_start_ = ThrowIfError( | 197 string_start_ = ThrowIfError( |
151 Dart_NewPersistentHandle(DartUtils::NewString("start"))); | 198 Dart_NewPersistentHandle(DartUtils::NewString("start"))); |
152 string_length_ = ThrowIfError( | 199 string_length_ = ThrowIfError( |
153 Dart_NewPersistentHandle(DartUtils::NewString("length"))); | 200 Dart_NewPersistentHandle(DartUtils::NewString("length"))); |
154 | 201 |
155 InitializeBuffers(dart_this); | 202 InitializeBuffers(dart_this); |
156 memio_ = memio_CreateIOLayer(kMemioBufferSize); | 203 filter_ = memio_CreateIOLayer(kMemioBufferSize); |
157 } | 204 } |
158 | 205 |
159 | 206 |
160 void TlsFilter::InitializeBuffers(Dart_Handle dart_this) { | 207 void TlsFilter::InitializeBuffers(Dart_Handle dart_this) { |
161 // Create TlsFilter buffers as ExternalUint8Array objects. | 208 // Create TlsFilter buffers as ExternalUint8Array objects. |
162 Dart_Handle dart_buffers_object = ThrowIfError( | 209 Dart_Handle dart_buffers_object = ThrowIfError( |
163 Dart_GetField(dart_this, DartUtils::NewString("buffers"))); | 210 Dart_GetField(dart_this, DartUtils::NewString("buffers"))); |
164 Dart_Handle dart_buffer_object = | 211 Dart_Handle dart_buffer_object = |
165 Dart_ListGetAt(dart_buffers_object, kReadPlaintext); | 212 Dart_ListGetAt(dart_buffers_object, kReadPlaintext); |
166 Dart_Handle tls_external_buffer_class = | 213 Dart_Handle tls_external_buffer_class = |
(...skipping 19 matching lines...) Expand all Loading... |
186 } | 233 } |
187 } | 234 } |
188 | 235 |
189 | 236 |
190 void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) { | 237 void TlsFilter::RegisterHandshakeCompleteCallback(Dart_Handle complete) { |
191 ASSERT(NULL == handshake_complete_); | 238 ASSERT(NULL == handshake_complete_); |
192 handshake_complete_ = ThrowIfError(Dart_NewPersistentHandle(complete)); | 239 handshake_complete_ = ThrowIfError(Dart_NewPersistentHandle(complete)); |
193 } | 240 } |
194 | 241 |
195 | 242 |
196 void TlsFilter::InitializeLibrary(const char* pkcert_database) { | 243 void TlsFilter::InitializeLibrary(const char* certificate_database, |
| 244 const char* password) { |
197 MutexLocker locker(&mutex_); | 245 MutexLocker locker(&mutex_); |
198 if (!library_initialized_) { | 246 if (!library_initialized_) { |
| 247 library_initialized_ = true; |
| 248 password_ = strdup(password); // This one copy persists until Dart exits. |
199 PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0); | 249 PR_Init(PR_USER_THREAD, PR_PRIORITY_NORMAL, 0); |
200 // TODO(whesse): Verify there are no UTF-8 issues here. | 250 // TODO(whesse): Verify there are no UTF-8 issues here. |
201 SECStatus status = NSS_Init(pkcert_database); | 251 SECStatus status = NSS_Init(certificate_database); |
202 if (status != SECSuccess) { | 252 if (status != SECSuccess) { |
203 ThrowPRException("Unsuccessful NSS_Init call."); | 253 ThrowPRException("Unsuccessful NSS_Init call."); |
204 } | 254 } |
205 | 255 |
206 status = NSS_SetDomesticPolicy(); | 256 status = NSS_SetDomesticPolicy(); |
207 if (status != SECSuccess) { | 257 if (status != SECSuccess) { |
208 ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call."); | 258 ThrowPRException("Unsuccessful NSS_SetDomesticPolicy call."); |
209 } | 259 } |
| 260 // Enable TLS, as well as SSL3 and SSL2. |
| 261 status = SSL_OptionSetDefault(SSL_ENABLE_TLS, PR_TRUE); |
| 262 if (status != SECSuccess) { |
| 263 ThrowPRException("Unsuccessful SSL_OptionSetDefault enable TLS call."); |
| 264 } |
210 } else { | 265 } else { |
211 ThrowException("Called TlsFilter::InitializeLibrary more than once"); | 266 ThrowException("Called TlsFilter::InitializeLibrary more than once"); |
212 } | 267 } |
213 } | 268 } |
214 | 269 |
| 270 char* PasswordCallback(PK11SlotInfo* slot, PRBool retry, void* arg) { |
| 271 if (!retry) { |
| 272 return PL_strdup(static_cast<char*>(arg)); // Freed by NSS internals. |
| 273 } |
| 274 return NULL; |
| 275 } |
215 | 276 |
216 void TlsFilter::Connect(const char* host, int port) { | 277 void TlsFilter::Connect(const char* host_name, |
| 278 int port, |
| 279 bool is_server, |
| 280 const char* certificate_name) { |
| 281 is_server_ = is_server; |
217 if (in_handshake_) { | 282 if (in_handshake_) { |
218 ThrowException("Connect called while already in handshake state."); | 283 ThrowException("Connect called while already in handshake state."); |
219 } | 284 } |
220 PRFileDesc* my_socket = memio_; | |
221 | 285 |
222 my_socket = SSL_ImportFD(NULL, my_socket); | 286 filter_ = SSL_ImportFD(NULL, filter_); |
223 if (my_socket == NULL) { | 287 if (filter_ == NULL) { |
224 ThrowPRException("Unsuccessful SSL_ImportFD call"); | 288 ThrowPRException("Unsuccessful SSL_ImportFD call"); |
225 } | 289 } |
226 | 290 |
227 if (SSL_SetURL(my_socket, host) == -1) { | 291 SECStatus status; |
228 ThrowPRException("Unsuccessful SetURL call"); | 292 if (is_server) { |
| 293 PK11_SetPasswordFunc(PasswordCallback); |
| 294 CERTCertDBHandle* certificate_database = CERT_GetDefaultCertDB(); |
| 295 if (certificate_database == NULL) { |
| 296 ThrowPRException("Certificate database cannot be loaded"); |
| 297 } |
| 298 CERTCertificate* certificate = CERT_FindCertByNameString( |
| 299 certificate_database, |
| 300 const_cast<char*>(certificate_name)); |
| 301 if (certificate == NULL) { |
| 302 ThrowPRException("Cannot find server certificate by name"); |
| 303 } |
| 304 SECKEYPrivateKey* key = PK11_FindKeyByAnyCert( |
| 305 certificate, |
| 306 static_cast<void*>(const_cast<char*>(password_))); |
| 307 if (key == NULL) { |
| 308 if (PR_GetError() == -8177) { |
| 309 ThrowPRException("Certificate database password incorrect"); |
| 310 } else { |
| 311 ThrowPRException("Unsuccessful PK11_FindKeyByAnyCert call." |
| 312 " Cannot find private key for certificate"); |
| 313 } |
| 314 } |
| 315 // kt_rsa (key type RSA) is an enum constant from the NSS libraries. |
| 316 // TODO(whesse): Allow different key types. |
| 317 status = SSL_ConfigSecureServer(filter_, certificate, key, kt_rsa); |
| 318 if (status != SECSuccess) { |
| 319 ThrowPRException("Unsuccessful SSL_ConfigSecureServer call"); |
| 320 } |
| 321 } else { // Client. |
| 322 if (SSL_SetURL(filter_, host_name) == -1) { |
| 323 ThrowPRException("Unsuccessful SetURL call"); |
| 324 } |
229 } | 325 } |
230 | 326 |
231 SECStatus status = SSL_ResetHandshake(my_socket, PR_FALSE); | 327 PRBool as_server = is_server ? PR_TRUE : PR_FALSE; // Convert bool to PRBool. |
| 328 status = SSL_ResetHandshake(filter_, as_server); |
232 if (status != SECSuccess) { | 329 if (status != SECSuccess) { |
233 ThrowPRException("Unsuccessful SSL_ResetHandshake call"); | 330 ThrowPRException("Unsuccessful SSL_ResetHandshake call"); |
234 } | 331 } |
235 | 332 |
236 // SetPeerAddress | 333 // SetPeerAddress |
237 PRNetAddr host_address; | 334 PRNetAddr host_address; |
238 char host_entry_buffer[PR_NETDB_BUF_SIZE]; | 335 char host_entry_buffer[PR_NETDB_BUF_SIZE]; |
239 PRHostEnt host_entry; | 336 PRHostEnt host_entry; |
240 PRStatus rv = PR_GetHostByName(host, host_entry_buffer, | 337 PRStatus rv = PR_GetHostByName(host_name, host_entry_buffer, |
241 PR_NETDB_BUF_SIZE, &host_entry); | 338 PR_NETDB_BUF_SIZE, &host_entry); |
242 if (rv != PR_SUCCESS) { | 339 if (rv != PR_SUCCESS) { |
243 ThrowPRException("Unsuccessful PR_GetHostByName call"); | 340 ThrowPRException("Unsuccessful PR_GetHostByName call"); |
244 } | 341 } |
245 | 342 |
246 int index = PR_EnumerateHostEnt(0, &host_entry, port, &host_address); | 343 int index = PR_EnumerateHostEnt(0, &host_entry, port, &host_address); |
247 if (index == -1 || index == 0) { | 344 if (index == -1 || index == 0) { |
248 ThrowPRException("Unsuccessful PR_EnumerateHostEnt call"); | 345 ThrowPRException("Unsuccessful PR_EnumerateHostEnt call"); |
249 } | 346 } |
250 | 347 memio_SetPeerName(filter_, &host_address); |
251 memio_SetPeerName(my_socket, &host_address); | |
252 memio_ = my_socket; | |
253 } | 348 } |
254 | 349 |
255 | 350 |
256 void TlsFilter::Handshake() { | 351 void TlsFilter::Handshake() { |
257 SECStatus status = SSL_ForceHandshake(memio_); | 352 SECStatus status = SSL_ForceHandshake(filter_); |
258 if (status == SECSuccess) { | 353 if (status == SECSuccess) { |
259 if (in_handshake_) { | 354 if (in_handshake_) { |
260 ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL)); | 355 ThrowIfError(Dart_InvokeClosure(handshake_complete_, 0, NULL)); |
261 in_handshake_ = false; | 356 in_handshake_ = false; |
262 } | 357 } |
263 } else { | 358 } else { |
264 PRErrorCode error = PR_GetError(); | 359 PRErrorCode error = PR_GetError(); |
265 if (error == PR_WOULD_BLOCK_ERROR) { | 360 if (error == PR_WOULD_BLOCK_ERROR) { |
266 if (!in_handshake_) { | 361 if (!in_handshake_) { |
267 in_handshake_ = true; | 362 in_handshake_ = true; |
268 } | 363 } |
269 } else { | 364 } else { |
270 ThrowPRException("Unexpected handshake error"); | 365 if (is_server_) { |
| 366 ThrowPRException("Unexpected handshake error in server"); |
| 367 } else { |
| 368 ThrowPRException("Unexpected handshake error in client"); |
| 369 } |
271 } | 370 } |
272 } | 371 } |
273 } | 372 } |
274 | 373 |
275 | 374 |
276 void TlsFilter::Destroy() { | 375 void TlsFilter::Destroy() { |
277 for (int i = 0; i < kNumBuffers; ++i) { | 376 for (int i = 0; i < kNumBuffers; ++i) { |
278 Dart_DeletePersistentHandle(dart_buffer_objects_[i]); | 377 Dart_DeletePersistentHandle(dart_buffer_objects_[i]); |
279 delete[] buffers_[i]; | 378 delete[] buffers_[i]; |
280 } | 379 } |
(...skipping 17 matching lines...) Expand all Loading... |
298 ASSERT(unsafe_length >= 0); | 397 ASSERT(unsafe_length >= 0); |
299 ASSERT(unsafe_length <= buffer_size_); | 398 ASSERT(unsafe_length <= buffer_size_); |
300 intptr_t start = static_cast<intptr_t>(unsafe_start); | 399 intptr_t start = static_cast<intptr_t>(unsafe_start); |
301 intptr_t length = static_cast<intptr_t>(unsafe_length); | 400 intptr_t length = static_cast<intptr_t>(unsafe_length); |
302 uint8_t* buffer = buffers_[buffer_index]; | 401 uint8_t* buffer = buffers_[buffer_index]; |
303 | 402 |
304 int bytes_processed = 0; | 403 int bytes_processed = 0; |
305 switch (buffer_index) { | 404 switch (buffer_index) { |
306 case kReadPlaintext: { | 405 case kReadPlaintext: { |
307 int bytes_free = buffer_size_ - start - length; | 406 int bytes_free = buffer_size_ - start - length; |
308 bytes_processed = PR_Read(memio_, | 407 bytes_processed = PR_Read(filter_, |
309 buffer + start + length, | 408 buffer + start + length, |
310 bytes_free); | 409 bytes_free); |
311 if (bytes_processed < 0) { | 410 if (bytes_processed < 0) { |
312 ASSERT(bytes_processed == -1); | 411 ASSERT(bytes_processed == -1); |
313 // TODO(whesse): Handle unexpected errors here. | 412 // TODO(whesse): Handle unexpected errors here. |
314 PRErrorCode pr_error = PR_GetError(); | 413 PRErrorCode pr_error = PR_GetError(); |
315 if (PR_WOULD_BLOCK_ERROR != pr_error) { | 414 if (PR_WOULD_BLOCK_ERROR != pr_error) { |
316 ThrowPRException("Error reading plaintext from TlsFilter"); | 415 ThrowPRException("Error reading plaintext from TlsFilter"); |
317 } | 416 } |
318 bytes_processed = 0; | 417 bytes_processed = 0; |
319 } | 418 } |
320 break; | 419 break; |
321 } | 420 } |
322 | 421 |
323 case kWriteEncrypted: { | 422 case kWriteEncrypted: { |
324 const uint8_t* buf1; | 423 const uint8_t* buf1; |
325 const uint8_t* buf2; | 424 const uint8_t* buf2; |
326 unsigned int len1; | 425 unsigned int len1; |
327 unsigned int len2; | 426 unsigned int len2; |
328 int bytes_free = buffer_size_ - start - length; | 427 int bytes_free = buffer_size_ - start - length; |
329 memio_Private* secret = memio_GetSecret(memio_); | 428 memio_Private* secret = memio_GetSecret(filter_); |
330 memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2); | 429 memio_GetWriteParams(secret, &buf1, &len1, &buf2, &len2); |
331 int bytes_to_send = | 430 int bytes_to_send = |
332 dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free)); | 431 dart::Utils::Minimum(len1, static_cast<unsigned>(bytes_free)); |
333 if (bytes_to_send > 0) { | 432 if (bytes_to_send > 0) { |
334 memmove(buffer + start + length, buf1, bytes_to_send); | 433 memmove(buffer + start + length, buf1, bytes_to_send); |
335 bytes_processed = bytes_to_send; | 434 bytes_processed = bytes_to_send; |
336 } | 435 } |
337 bytes_to_send = dart::Utils::Minimum(len2, | 436 bytes_to_send = dart::Utils::Minimum(len2, |
338 static_cast<unsigned>(bytes_free - bytes_processed)); | 437 static_cast<unsigned>(bytes_free - bytes_processed)); |
339 if (bytes_to_send > 0) { | 438 if (bytes_to_send > 0) { |
340 memmove(buffer + start + length + bytes_processed, buf2, | 439 memmove(buffer + start + length + bytes_processed, buf2, |
341 bytes_to_send); | 440 bytes_to_send); |
342 bytes_processed += bytes_to_send; | 441 bytes_processed += bytes_to_send; |
343 } | 442 } |
344 if (bytes_processed > 0) { | 443 if (bytes_processed > 0) { |
345 memio_PutWriteResult(secret, bytes_processed); | 444 memio_PutWriteResult(secret, bytes_processed); |
346 } | 445 } |
347 break; | 446 break; |
348 } | 447 } |
349 | 448 |
350 case kReadEncrypted: { | 449 case kReadEncrypted: { |
351 if (length > 0) { | 450 if (length > 0) { |
352 bytes_processed = length; | 451 bytes_processed = length; |
353 memio_Private* secret = memio_GetSecret(memio_); | 452 memio_Private* secret = memio_GetSecret(filter_); |
354 uint8_t* memio_buf; | 453 uint8_t* filter_buf; |
355 int free_bytes = memio_GetReadParams(secret, &memio_buf); | 454 int free_bytes = memio_GetReadParams(secret, &filter_buf); |
356 if (free_bytes < bytes_processed) bytes_processed = free_bytes; | 455 if (free_bytes < bytes_processed) bytes_processed = free_bytes; |
357 memmove(memio_buf, | 456 memmove(filter_buf, |
358 buffer + start, | 457 buffer + start, |
359 bytes_processed); | 458 bytes_processed); |
360 memio_PutReadResult(secret, bytes_processed); | 459 memio_PutReadResult(secret, bytes_processed); |
361 } | 460 } |
362 break; | 461 break; |
363 } | 462 } |
364 | 463 |
365 case kWritePlaintext: { | 464 case kWritePlaintext: { |
366 if (length > 0) { | 465 if (length > 0) { |
367 bytes_processed = PR_Write(memio_, | 466 bytes_processed = PR_Write(filter_, |
368 buffer + start, | 467 buffer + start, |
369 length); | 468 length); |
370 } | 469 } |
371 | 470 |
372 if (bytes_processed < 0) { | 471 if (bytes_processed < 0) { |
373 ASSERT(bytes_processed == -1); | 472 ASSERT(bytes_processed == -1); |
374 // TODO(whesse): Handle unexpected errors here. | 473 // TODO(whesse): Handle unexpected errors here. |
375 PRErrorCode pr_error = PR_GetError(); | 474 PRErrorCode pr_error = PR_GetError(); |
376 if (PR_WOULD_BLOCK_ERROR != pr_error) { | 475 if (PR_WOULD_BLOCK_ERROR != pr_error) { |
377 ThrowPRException("Error reading plaintext from TlsFilter"); | 476 ThrowPRException("Error reading plaintext from TlsFilter"); |
378 } | 477 } |
379 bytes_processed = 0; | 478 bytes_processed = 0; |
380 } | 479 } |
381 break; | 480 break; |
382 } | 481 } |
383 } | 482 } |
384 return bytes_processed; | 483 return bytes_processed; |
385 } | 484 } |
OLD | NEW |