Index: net/socket/ssl_client_socket_nss.cc |
diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc |
index 05549da933c6b8fb882b612c003514bc9340411e..ca0485714937356bf64695d2bc743a0f9cadd2ca 100644 |
--- a/net/socket/ssl_client_socket_nss.cc |
+++ b/net/socket/ssl_client_socket_nss.cc |
@@ -81,15 +81,18 @@ namespace net { |
#if 1 |
#define EnterFunction(x) |
#define LeaveFunction(x) |
-#define GotoState(s) next_state_ = s |
+#define GotoState(s) next_handshake_state_ = s |
#define LogData(s, len) |
#else |
#define EnterFunction(x) LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
- " enter " << x << "; next_state " << next_state_ |
+ " enter " << x << \ |
+ "; next_handshake_state " << next_handshake_state_ |
#define LeaveFunction(x) LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
- " leave " << x << "; next_state " << next_state_ |
+ " leave " << x << \ |
+ "; next_handshake_state " << next_handshake_state_ |
#define GotoState(s) do { LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
- " jump to state " << s; next_state_ = s; } while (0) |
+ " jump to state " << s; \ |
+ next_handshake_state_ = s; } while (0) |
#define LogData(s, len) LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
" data [" << std::string(s, len) << "]"; |
@@ -193,15 +196,17 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocket* transport_socket, |
buffer_recv_callback_(this, &SSLClientSocketNSS::BufferRecvComplete), |
transport_send_busy_(false), |
transport_recv_busy_(false), |
- io_callback_(this, &SSLClientSocketNSS::OnIOComplete), |
+ handshake_io_callback_(this, &SSLClientSocketNSS::OnHandshakeIOComplete), |
transport_(transport_socket), |
hostname_(hostname), |
ssl_config_(ssl_config), |
user_connect_callback_(NULL), |
- user_callback_(NULL), |
- user_buf_len_(0), |
+ user_read_callback_(NULL), |
+ user_write_callback_(NULL), |
+ user_read_buf_len_(0), |
+ user_write_buf_len_(0), |
completed_handshake_(false), |
- next_state_(STATE_NONE), |
+ next_handshake_state_(STATE_NONE), |
nss_fd_(NULL), |
nss_bufs_(NULL) { |
EnterFunction(""); |
@@ -229,10 +234,12 @@ int SSLClientSocketNSS::Init() { |
int SSLClientSocketNSS::Connect(CompletionCallback* callback) { |
EnterFunction(""); |
DCHECK(transport_.get()); |
- DCHECK(next_state_ == STATE_NONE); |
- DCHECK(!user_callback_); |
+ DCHECK(next_handshake_state_ == STATE_NONE); |
+ DCHECK(!user_read_callback_); |
+ DCHECK(!user_write_callback_); |
DCHECK(!user_connect_callback_); |
- DCHECK(!user_buf_); |
+ DCHECK(!user_read_buf_); |
+ DCHECK(!user_write_buf_); |
if (Init() != OK) { |
NOTREACHED() << "Couldn't initialize nss"; |
@@ -321,8 +328,8 @@ int SSLClientSocketNSS::Connect(CompletionCallback* callback) { |
// Tell SSL we're a client; needed if not letting NSPR do socket I/O |
SSL_ResetHandshake(nss_fd_, 0); |
- GotoState(STATE_HANDSHAKE_READ); |
- rv = DoLoop(OK); |
+ GotoState(STATE_HANDSHAKE); |
+ rv = DoHandshakeLoop(OK); |
if (rv == ERR_IO_PENDING) |
user_connect_callback_ = callback; |
@@ -348,7 +355,7 @@ void SSLClientSocketNSS::Disconnect() { |
} |
// Shut down anything that may call us back (through buffer_send_callback_, |
- // buffer_recv_callback, or io_callback_). |
+ // buffer_recv_callback, or handshake_io_callback_). |
verifier_.reset(); |
transport_->Disconnect(); |
@@ -356,9 +363,12 @@ void SSLClientSocketNSS::Disconnect() { |
transport_send_busy_ = false; |
transport_recv_busy_ = false; |
user_connect_callback_ = NULL; |
- user_callback_ = NULL; |
- user_buf_ = NULL; |
- user_buf_len_ = 0; |
+ user_read_callback_ = NULL; |
+ user_write_callback_ = NULL; |
+ user_read_buf_ = NULL; |
+ user_read_buf_len_ = 0; |
+ user_write_buf_ = NULL; |
+ user_write_buf_len_ = 0; |
server_cert_ = NULL; |
server_cert_verify_result_.Reset(); |
completed_handshake_ = false; |
@@ -398,38 +408,48 @@ int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, |
CompletionCallback* callback) { |
EnterFunction(buf_len); |
DCHECK(completed_handshake_); |
- DCHECK(next_state_ == STATE_NONE); |
- DCHECK(!user_callback_); |
+ DCHECK(next_handshake_state_ == STATE_NONE); |
+ DCHECK(!user_read_callback_); |
DCHECK(!user_connect_callback_); |
- DCHECK(!user_buf_); |
+ DCHECK(!user_read_buf_); |
+ DCHECK(nss_bufs_); |
- user_buf_ = buf; |
- user_buf_len_ = buf_len; |
+ user_read_buf_ = buf; |
+ user_read_buf_len_ = buf_len; |
+ |
+ int rv = DoReadLoop(OK); |
- GotoState(STATE_PAYLOAD_READ); |
- int rv = DoLoop(OK); |
if (rv == ERR_IO_PENDING) |
wtc
2009/10/14 01:54:15
Please add braces {} to "if" when "else" has brace
|
- user_callback_ = callback; |
+ user_read_callback_ = callback; |
+ else { |
+ user_read_buf_ = NULL; |
+ user_read_buf_len_ = 0; |
+ } |
LeaveFunction(rv); |
return rv; |
} |
int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, |
- CompletionCallback* callback) { |
+ CompletionCallback* callback) { |
EnterFunction(buf_len); |
DCHECK(completed_handshake_); |
- DCHECK(next_state_ == STATE_NONE); |
- DCHECK(!user_callback_); |
+ DCHECK(next_handshake_state_ == STATE_NONE); |
+ DCHECK(!user_write_callback_); |
DCHECK(!user_connect_callback_); |
- DCHECK(!user_buf_); |
+ DCHECK(!user_write_buf_); |
+ DCHECK(nss_bufs_); |
+ |
+ user_write_buf_ = buf; |
+ user_write_buf_len_ = buf_len; |
- user_buf_ = buf; |
- user_buf_len_ = buf_len; |
+ int rv = DoWriteLoop(OK); |
- GotoState(STATE_PAYLOAD_WRITE); |
- int rv = DoLoop(OK); |
if (rv == ERR_IO_PENDING) |
- user_callback_ = callback; |
+ user_write_callback_ = callback; |
wtc
2009/10/14 01:54:15
Same here: both "if" and "else" should have braces
|
+ else { |
+ user_write_buf_ = NULL; |
+ user_write_buf_len_ = 0; |
+ } |
LeaveFunction(rv); |
return rv; |
} |
@@ -491,15 +511,32 @@ void SSLClientSocketNSS::GetSSLCertRequestInfo( |
// TODO(wtc): implement this. |
} |
-void SSLClientSocketNSS::DoCallback(int rv) { |
+void SSLClientSocketNSS::DoReadCallback(int rv) { |
EnterFunction(rv); |
DCHECK(rv != ERR_IO_PENDING); |
- DCHECK(user_callback_); |
+ DCHECK(user_read_callback_); |
- // Since Run may result in Read being called, clear |user_callback_| up front. |
- CompletionCallback* c = user_callback_; |
- user_callback_ = NULL; |
- user_buf_ = NULL; |
+ // Since Run may result in Read being called, clear |user_read_callback_| |
+ // up front. |
+ CompletionCallback* c = user_read_callback_; |
+ user_read_callback_ = NULL; |
+ user_read_buf_ = NULL; |
+ user_read_buf_len_ = 0; |
+ c->Run(rv); |
+ LeaveFunction(""); |
+} |
+ |
+void SSLClientSocketNSS::DoWriteCallback(int rv) { |
+ EnterFunction(rv); |
+ DCHECK(rv != ERR_IO_PENDING); |
+ DCHECK(user_write_callback_); |
+ |
+ // Since Run may result in Write being called, clear |user_write_callback_| |
+ // up front. |
+ CompletionCallback* c = user_write_callback_; |
+ user_write_callback_ = NULL; |
+ user_write_buf_ = NULL; |
+ user_write_buf_len_ = 0; |
c->Run(rv); |
LeaveFunction(""); |
} |
@@ -516,24 +553,71 @@ void SSLClientSocketNSS::DoConnectCallback(int rv) { |
DCHECK_NE(rv, ERR_IO_PENDING); |
DCHECK(user_connect_callback_); |
- // Since Run may result in Read being called, clear |user_connect_callback_| |
- // up front. |
CompletionCallback* c = user_connect_callback_; |
user_connect_callback_ = NULL; |
c->Run(rv > OK ? OK : rv); |
LeaveFunction(""); |
} |
-void SSLClientSocketNSS::OnIOComplete(int result) { |
+void SSLClientSocketNSS::OnHandshakeIOComplete(int result) { |
EnterFunction(result); |
- int rv = DoLoop(result); |
- if (rv != ERR_IO_PENDING) { |
- if (user_callback_) { |
- DoCallback(rv); |
- } else if (user_connect_callback_) { |
- DoConnectCallback(rv); |
- } |
+ int rv = DoHandshakeLoop(result); |
+ if (rv != ERR_IO_PENDING) |
+ DoConnectCallback(rv); |
+ LeaveFunction(""); |
+} |
+ |
+void SSLClientSocketNSS::OnSendComplete(int result) { |
+ EnterFunction(result); |
+ if (next_handshake_state_ != STATE_NONE) { |
+ // In handshake phase. |
+ OnHandshakeIOComplete(result); |
+ LeaveFunction(""); |
+ return; |
+ } |
+ |
+ // OnSendComplete may need to call DoPayloadRead while the renegotiation |
+ // handshake is in progress. |
+ int rv_read = ERR_IO_PENDING; |
+ int rv_write = ERR_IO_PENDING; |
+ bool network_moved; |
+ do { |
+ if (user_read_buf_) |
+ rv_read = DoPayloadRead(); |
+ if (user_write_buf_) |
+ rv_write = DoPayloadWrite(); |
+ network_moved = DoTransportIO(); |
+ } while (rv_read == ERR_IO_PENDING && |
+ rv_write == ERR_IO_PENDING && |
+ network_moved); |
+ |
+ if (user_read_buf_ && rv_read != ERR_IO_PENDING) |
+ DoReadCallback(rv_read); |
+ if (user_write_buf_ && rv_write != ERR_IO_PENDING) |
+ DoWriteCallback(rv_write); |
+ |
+ LeaveFunction(""); |
+} |
+ |
+void SSLClientSocketNSS::OnRecvComplete(int result) { |
+ EnterFunction(result); |
+ if (next_handshake_state_ != STATE_NONE) { |
+ // In handshake phase. |
+ OnHandshakeIOComplete(result); |
+ LeaveFunction(""); |
+ return; |
} |
+ |
+ // Network layer received some data, check if client requested to read |
+ // decrypted data. |
+ if (!user_read_buf_) { |
+ LeaveFunction(""); |
+ return; |
+ } |
+ |
+ int rv = DoReadLoop(result); |
+ if (rv != ERR_IO_PENDING) |
+ DoReadCallback(rv); |
LeaveFunction(""); |
} |
@@ -549,6 +633,19 @@ static PRErrorCode MapErrorToNSS(int result) { |
} |
// Do network I/O between the given buffer and the given socket. |
+// Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) |
+bool SSLClientSocketNSS::DoTransportIO() { |
+ EnterFunction(""); |
+ bool network_moved = false; |
+ if (nss_bufs_ != NULL) { |
+ int nsent = BufferSend(); |
+ int nreceived = BufferRecv(); |
+ network_moved = (nsent > 0 || nreceived >= 0); |
+ } |
+ LeaveFunction(network_moved); |
+ return network_moved; |
+} |
+ |
// Return 0 for EOF, |
// > 0 for bytes transferred immediately, |
// < 0 for error (or the non-error ERR_IO_PENDING). |
@@ -580,7 +677,7 @@ void SSLClientSocketNSS::BufferSendComplete(int result) { |
EnterFunction(result); |
memio_PutWriteResult(nss_bufs_, result); |
transport_send_busy_ = false; |
- OnIOComplete(result); |
+ OnSendComplete(result); |
LeaveFunction(""); |
} |
@@ -621,29 +718,28 @@ void SSLClientSocketNSS::BufferRecvComplete(int result) { |
recv_buffer_ = NULL; |
memio_PutReadResult(nss_bufs_, result); |
transport_recv_busy_ = false; |
- OnIOComplete(result); |
+ OnRecvComplete(result); |
LeaveFunction(""); |
} |
-int SSLClientSocketNSS::DoLoop(int last_io_result) { |
+int SSLClientSocketNSS::DoHandshakeLoop(int last_io_result) { |
EnterFunction(last_io_result); |
bool network_moved; |
int rv = last_io_result; |
do { |
- network_moved = false; |
// Default to STATE_NONE for next state. |
// (This is a quirk carried over from the windows |
// implementation. It makes reading the logs a bit harder.) |
// State handlers can and often do call GotoState just |
// to stay in the current state. |
- State state = next_state_; |
+ State state = next_handshake_state_; |
GotoState(STATE_NONE); |
switch (state) { |
case STATE_NONE: |
// we're just pumping data between the buffer and the network |
break; |
- case STATE_HANDSHAKE_READ: |
- rv = DoHandshakeRead(); |
+ case STATE_HANDSHAKE: |
+ rv = DoHandshake(); |
break; |
case STATE_VERIFY_CERT: |
DCHECK(rv == OK); |
@@ -652,12 +748,6 @@ int SSLClientSocketNSS::DoLoop(int last_io_result) { |
case STATE_VERIFY_CERT_COMPLETE: |
rv = DoVerifyCertComplete(rv); |
break; |
- case STATE_PAYLOAD_READ: |
- rv = DoPayloadRead(); |
- break; |
- case STATE_PAYLOAD_WRITE: |
- rv = DoPayloadWrite(); |
- break; |
default: |
rv = ERR_UNEXPECTED; |
NOTREACHED() << "unexpected state"; |
@@ -665,13 +755,53 @@ int SSLClientSocketNSS::DoLoop(int last_io_result) { |
} |
// Do the actual network I/O |
- if (nss_bufs_ != NULL) { |
- int nsent = BufferSend(); |
- int nreceived = BufferRecv(); |
- network_moved = (nsent > 0 || nreceived >= 0); |
- } |
+ network_moved = DoTransportIO(); |
} while ((rv != ERR_IO_PENDING || network_moved) && |
- next_state_ != STATE_NONE); |
+ next_handshake_state_ != STATE_NONE); |
+ LeaveFunction(""); |
+ return rv; |
+} |
+ |
+int SSLClientSocketNSS::DoReadLoop(int result) { |
+ EnterFunction(""); |
+ DCHECK(completed_handshake_); |
+ DCHECK(next_handshake_state_ == STATE_NONE); |
+ |
+ if (result < 0) |
+ return result; |
+ |
+ if (!nss_bufs_) |
+ return ERR_UNEXPECTED; |
+ |
+ bool network_moved; |
+ int rv; |
+ do { |
+ rv = DoPayloadRead(); |
+ network_moved = DoTransportIO(); |
+ } while (rv == ERR_IO_PENDING && network_moved); |
+ |
+ LeaveFunction(""); |
+ return rv; |
+} |
+ |
+int SSLClientSocketNSS::DoWriteLoop(int result) { |
+ EnterFunction(""); |
+ DCHECK(completed_handshake_); |
+ DCHECK(next_handshake_state_ == STATE_NONE); |
+ |
+ if (result < 0) |
+ return result; |
+ |
+ if (!nss_bufs_) |
+ return ERR_UNEXPECTED; |
+ |
+ bool network_moved; |
+ int rv; |
+ do { |
+ rv = DoPayloadWrite(); |
+ network_moved = DoTransportIO(); |
+ } while (rv == ERR_IO_PENDING && network_moved); |
+ |
LeaveFunction(""); |
return rv; |
} |
@@ -701,7 +831,7 @@ void SSLClientSocketNSS::HandshakeCallback(PRFileDesc* socket, |
that->UpdateServerCert(); |
} |
-int SSLClientSocketNSS::DoHandshakeRead() { |
+int SSLClientSocketNSS::DoHandshake() { |
EnterFunction(""); |
int net_error = net::OK; |
int rv = SSL_ForceHandshake(nss_fd_); |
@@ -723,7 +853,7 @@ int SSLClientSocketNSS::DoHandshakeRead() { |
// If not done, stay in this state |
if (net_error == ERR_IO_PENDING) { |
- GotoState(STATE_HANDSHAKE_READ); |
+ GotoState(STATE_HANDSHAKE); |
} else { |
LOG(ERROR) << "handshake failed; NSS error code " << prerr |
<< ", net_error " << net_error; |
@@ -744,7 +874,8 @@ int SSLClientSocketNSS::DoVerifyCert(int result) { |
flags |= X509Certificate::VERIFY_EV_CERT; |
verifier_.reset(new CertVerifier); |
return verifier_->Verify(server_cert_, hostname_, flags, |
- &server_cert_verify_result_, &io_callback_); |
+ &server_cert_verify_result_, |
+ &handshake_io_callback_); |
} |
// Derived from AuthCertificateCallback() in |
@@ -805,46 +936,43 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { |
// TODO(ukai): we may not need this call because it is now harmless to have an |
// session with a bad cert. |
InvalidateSessionIfBadCertificate(); |
- // Exit DoLoop and return the result to the caller to Connect. |
- DCHECK(next_state_ == STATE_NONE); |
+ // Exit DoHandshakeLoop and return the result to the caller to Connect. |
+ DCHECK(next_handshake_state_ == STATE_NONE); |
return result; |
} |
int SSLClientSocketNSS::DoPayloadRead() { |
- EnterFunction(user_buf_len_); |
- int rv = PR_Read(nss_fd_, user_buf_->data(), user_buf_len_); |
+ EnterFunction(user_read_buf_len_); |
+ DCHECK(user_read_buf_); |
+ DCHECK(user_read_buf_len_ > 0); |
+ int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_); |
if (rv >= 0) { |
- LogData(user_buf_->data(), rv); |
- user_buf_ = NULL; |
+ LogData(user_read_buf_->data(), rv); |
LeaveFunction(""); |
return rv; |
} |
PRErrorCode prerr = PR_GetError(); |
if (prerr == PR_WOULD_BLOCK_ERROR) { |
- GotoState(STATE_PAYLOAD_READ); |
LeaveFunction(""); |
return ERR_IO_PENDING; |
} |
- user_buf_ = NULL; |
LeaveFunction(""); |
return NetErrorFromNSPRError(prerr); |
} |
int SSLClientSocketNSS::DoPayloadWrite() { |
- EnterFunction(user_buf_len_); |
- int rv = PR_Write(nss_fd_, user_buf_->data(), user_buf_len_); |
+ EnterFunction(user_write_buf_len_); |
+ DCHECK(user_write_buf_); |
+ int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_); |
if (rv >= 0) { |
- LogData(user_buf_->data(), rv); |
- user_buf_ = NULL; |
+ LogData(user_write_buf_->data(), rv); |
LeaveFunction(""); |
return rv; |
} |
PRErrorCode prerr = PR_GetError(); |
if (prerr == PR_WOULD_BLOCK_ERROR) { |
- GotoState(STATE_PAYLOAD_WRITE); |
return ERR_IO_PENDING; |
} |
- user_buf_ = NULL; |
LeaveFunction(""); |
return NetErrorFromNSPRError(prerr); |
} |