Chromium Code Reviews| Index: net/socket/ssl_client_socket_nss.cc |
| diff --git a/net/socket/ssl_client_socket_nss.cc b/net/socket/ssl_client_socket_nss.cc |
| index 05549da933c6b8fb882b612c003514bc9340411e..ca0485714937356bf64695d2bc743a0f9cadd2ca 100644 |
| --- a/net/socket/ssl_client_socket_nss.cc |
| +++ b/net/socket/ssl_client_socket_nss.cc |
| @@ -81,15 +81,18 @@ namespace net { |
| #if 1 |
| #define EnterFunction(x) |
| #define LeaveFunction(x) |
| -#define GotoState(s) next_state_ = s |
| +#define GotoState(s) next_handshake_state_ = s |
| #define LogData(s, len) |
| #else |
| #define EnterFunction(x) LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
| - " enter " << x << "; next_state " << next_state_ |
| + " enter " << x << \ |
| + "; next_handshake_state " << next_handshake_state_ |
| #define LeaveFunction(x) LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
| - " leave " << x << "; next_state " << next_state_ |
| + " leave " << x << \ |
| + "; next_handshake_state " << next_handshake_state_ |
| #define GotoState(s) do { LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
| - " jump to state " << s; next_state_ = s; } while (0) |
| + " jump to state " << s; \ |
| + next_handshake_state_ = s; } while (0) |
| #define LogData(s, len) LOG(INFO) << (void *)this << " " << __FUNCTION__ << \ |
| " data [" << std::string(s, len) << "]"; |
| @@ -193,15 +196,17 @@ SSLClientSocketNSS::SSLClientSocketNSS(ClientSocket* transport_socket, |
| buffer_recv_callback_(this, &SSLClientSocketNSS::BufferRecvComplete), |
| transport_send_busy_(false), |
| transport_recv_busy_(false), |
| - io_callback_(this, &SSLClientSocketNSS::OnIOComplete), |
| + handshake_io_callback_(this, &SSLClientSocketNSS::OnHandshakeIOComplete), |
| transport_(transport_socket), |
| hostname_(hostname), |
| ssl_config_(ssl_config), |
| user_connect_callback_(NULL), |
| - user_callback_(NULL), |
| - user_buf_len_(0), |
| + user_read_callback_(NULL), |
| + user_write_callback_(NULL), |
| + user_read_buf_len_(0), |
| + user_write_buf_len_(0), |
| completed_handshake_(false), |
| - next_state_(STATE_NONE), |
| + next_handshake_state_(STATE_NONE), |
| nss_fd_(NULL), |
| nss_bufs_(NULL) { |
| EnterFunction(""); |
| @@ -229,10 +234,12 @@ int SSLClientSocketNSS::Init() { |
| int SSLClientSocketNSS::Connect(CompletionCallback* callback) { |
| EnterFunction(""); |
| DCHECK(transport_.get()); |
| - DCHECK(next_state_ == STATE_NONE); |
| - DCHECK(!user_callback_); |
| + DCHECK(next_handshake_state_ == STATE_NONE); |
| + DCHECK(!user_read_callback_); |
| + DCHECK(!user_write_callback_); |
| DCHECK(!user_connect_callback_); |
| - DCHECK(!user_buf_); |
| + DCHECK(!user_read_buf_); |
| + DCHECK(!user_write_buf_); |
| if (Init() != OK) { |
| NOTREACHED() << "Couldn't initialize nss"; |
| @@ -321,8 +328,8 @@ int SSLClientSocketNSS::Connect(CompletionCallback* callback) { |
| // Tell SSL we're a client; needed if not letting NSPR do socket I/O |
| SSL_ResetHandshake(nss_fd_, 0); |
| - GotoState(STATE_HANDSHAKE_READ); |
| - rv = DoLoop(OK); |
| + GotoState(STATE_HANDSHAKE); |
| + rv = DoHandshakeLoop(OK); |
| if (rv == ERR_IO_PENDING) |
| user_connect_callback_ = callback; |
| @@ -348,7 +355,7 @@ void SSLClientSocketNSS::Disconnect() { |
| } |
| // Shut down anything that may call us back (through buffer_send_callback_, |
| - // buffer_recv_callback, or io_callback_). |
| + // buffer_recv_callback, or handshake_io_callback_). |
| verifier_.reset(); |
| transport_->Disconnect(); |
| @@ -356,9 +363,12 @@ void SSLClientSocketNSS::Disconnect() { |
| transport_send_busy_ = false; |
| transport_recv_busy_ = false; |
| user_connect_callback_ = NULL; |
| - user_callback_ = NULL; |
| - user_buf_ = NULL; |
| - user_buf_len_ = 0; |
| + user_read_callback_ = NULL; |
| + user_write_callback_ = NULL; |
| + user_read_buf_ = NULL; |
| + user_read_buf_len_ = 0; |
| + user_write_buf_ = NULL; |
| + user_write_buf_len_ = 0; |
| server_cert_ = NULL; |
| server_cert_verify_result_.Reset(); |
| completed_handshake_ = false; |
| @@ -398,38 +408,48 @@ int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, |
| CompletionCallback* callback) { |
| EnterFunction(buf_len); |
| DCHECK(completed_handshake_); |
| - DCHECK(next_state_ == STATE_NONE); |
| - DCHECK(!user_callback_); |
| + DCHECK(next_handshake_state_ == STATE_NONE); |
| + DCHECK(!user_read_callback_); |
| DCHECK(!user_connect_callback_); |
| - DCHECK(!user_buf_); |
| + DCHECK(!user_read_buf_); |
| + DCHECK(nss_bufs_); |
| - user_buf_ = buf; |
| - user_buf_len_ = buf_len; |
| + user_read_buf_ = buf; |
| + user_read_buf_len_ = buf_len; |
| + |
| + int rv = DoReadLoop(OK); |
| - GotoState(STATE_PAYLOAD_READ); |
| - int rv = DoLoop(OK); |
| if (rv == ERR_IO_PENDING) |
|
wtc
2009/10/14 01:54:15
Please add braces {} to "if" when "else" has brace
|
| - user_callback_ = callback; |
| + user_read_callback_ = callback; |
| + else { |
| + user_read_buf_ = NULL; |
| + user_read_buf_len_ = 0; |
| + } |
| LeaveFunction(rv); |
| return rv; |
| } |
| int SSLClientSocketNSS::Write(IOBuffer* buf, int buf_len, |
| - CompletionCallback* callback) { |
| + CompletionCallback* callback) { |
| EnterFunction(buf_len); |
| DCHECK(completed_handshake_); |
| - DCHECK(next_state_ == STATE_NONE); |
| - DCHECK(!user_callback_); |
| + DCHECK(next_handshake_state_ == STATE_NONE); |
| + DCHECK(!user_write_callback_); |
| DCHECK(!user_connect_callback_); |
| - DCHECK(!user_buf_); |
| + DCHECK(!user_write_buf_); |
| + DCHECK(nss_bufs_); |
| + |
| + user_write_buf_ = buf; |
| + user_write_buf_len_ = buf_len; |
| - user_buf_ = buf; |
| - user_buf_len_ = buf_len; |
| + int rv = DoWriteLoop(OK); |
| - GotoState(STATE_PAYLOAD_WRITE); |
| - int rv = DoLoop(OK); |
| if (rv == ERR_IO_PENDING) |
| - user_callback_ = callback; |
| + user_write_callback_ = callback; |
|
wtc
2009/10/14 01:54:15
Same here: both "if" and "else" should have braces
|
| + else { |
| + user_write_buf_ = NULL; |
| + user_write_buf_len_ = 0; |
| + } |
| LeaveFunction(rv); |
| return rv; |
| } |
| @@ -491,15 +511,32 @@ void SSLClientSocketNSS::GetSSLCertRequestInfo( |
| // TODO(wtc): implement this. |
| } |
| -void SSLClientSocketNSS::DoCallback(int rv) { |
| +void SSLClientSocketNSS::DoReadCallback(int rv) { |
| EnterFunction(rv); |
| DCHECK(rv != ERR_IO_PENDING); |
| - DCHECK(user_callback_); |
| + DCHECK(user_read_callback_); |
| - // Since Run may result in Read being called, clear |user_callback_| up front. |
| - CompletionCallback* c = user_callback_; |
| - user_callback_ = NULL; |
| - user_buf_ = NULL; |
| + // Since Run may result in Read being called, clear |user_read_callback_| |
| + // up front. |
| + CompletionCallback* c = user_read_callback_; |
| + user_read_callback_ = NULL; |
| + user_read_buf_ = NULL; |
| + user_read_buf_len_ = 0; |
| + c->Run(rv); |
| + LeaveFunction(""); |
| +} |
| + |
| +void SSLClientSocketNSS::DoWriteCallback(int rv) { |
| + EnterFunction(rv); |
| + DCHECK(rv != ERR_IO_PENDING); |
| + DCHECK(user_write_callback_); |
| + |
| + // Since Run may result in Write being called, clear |user_write_callback_| |
| + // up front. |
| + CompletionCallback* c = user_write_callback_; |
| + user_write_callback_ = NULL; |
| + user_write_buf_ = NULL; |
| + user_write_buf_len_ = 0; |
| c->Run(rv); |
| LeaveFunction(""); |
| } |
| @@ -516,24 +553,71 @@ void SSLClientSocketNSS::DoConnectCallback(int rv) { |
| DCHECK_NE(rv, ERR_IO_PENDING); |
| DCHECK(user_connect_callback_); |
| - // Since Run may result in Read being called, clear |user_connect_callback_| |
| - // up front. |
| CompletionCallback* c = user_connect_callback_; |
| user_connect_callback_ = NULL; |
| c->Run(rv > OK ? OK : rv); |
| LeaveFunction(""); |
| } |
| -void SSLClientSocketNSS::OnIOComplete(int result) { |
| +void SSLClientSocketNSS::OnHandshakeIOComplete(int result) { |
| EnterFunction(result); |
| - int rv = DoLoop(result); |
| - if (rv != ERR_IO_PENDING) { |
| - if (user_callback_) { |
| - DoCallback(rv); |
| - } else if (user_connect_callback_) { |
| - DoConnectCallback(rv); |
| - } |
| + int rv = DoHandshakeLoop(result); |
| + if (rv != ERR_IO_PENDING) |
| + DoConnectCallback(rv); |
| + LeaveFunction(""); |
| +} |
| + |
| +void SSLClientSocketNSS::OnSendComplete(int result) { |
| + EnterFunction(result); |
| + if (next_handshake_state_ != STATE_NONE) { |
| + // In handshake phase. |
| + OnHandshakeIOComplete(result); |
| + LeaveFunction(""); |
| + return; |
| + } |
| + |
| + // OnSendComplete may need to call DoPayloadRead while the renegotiation |
| + // handshake is in progress. |
| + int rv_read = ERR_IO_PENDING; |
| + int rv_write = ERR_IO_PENDING; |
| + bool network_moved; |
| + do { |
| + if (user_read_buf_) |
| + rv_read = DoPayloadRead(); |
| + if (user_write_buf_) |
| + rv_write = DoPayloadWrite(); |
| + network_moved = DoTransportIO(); |
| + } while (rv_read == ERR_IO_PENDING && |
| + rv_write == ERR_IO_PENDING && |
| + network_moved); |
| + |
| + if (user_read_buf_ && rv_read != ERR_IO_PENDING) |
| + DoReadCallback(rv_read); |
| + if (user_write_buf_ && rv_write != ERR_IO_PENDING) |
| + DoWriteCallback(rv_write); |
| + |
| + LeaveFunction(""); |
| +} |
| + |
| +void SSLClientSocketNSS::OnRecvComplete(int result) { |
| + EnterFunction(result); |
| + if (next_handshake_state_ != STATE_NONE) { |
| + // In handshake phase. |
| + OnHandshakeIOComplete(result); |
| + LeaveFunction(""); |
| + return; |
| } |
| + |
| + // Network layer received some data, check if client requested to read |
| + // decrypted data. |
| + if (!user_read_buf_) { |
| + LeaveFunction(""); |
| + return; |
| + } |
| + |
| + int rv = DoReadLoop(result); |
| + if (rv != ERR_IO_PENDING) |
| + DoReadCallback(rv); |
| LeaveFunction(""); |
| } |
| @@ -549,6 +633,19 @@ static PRErrorCode MapErrorToNSS(int result) { |
| } |
| // Do network I/O between the given buffer and the given socket. |
| +// Return true if some I/O performed, false otherwise (error or ERR_IO_PENDING) |
| +bool SSLClientSocketNSS::DoTransportIO() { |
| + EnterFunction(""); |
| + bool network_moved = false; |
| + if (nss_bufs_ != NULL) { |
| + int nsent = BufferSend(); |
| + int nreceived = BufferRecv(); |
| + network_moved = (nsent > 0 || nreceived >= 0); |
| + } |
| + LeaveFunction(network_moved); |
| + return network_moved; |
| +} |
| + |
| // Return 0 for EOF, |
| // > 0 for bytes transferred immediately, |
| // < 0 for error (or the non-error ERR_IO_PENDING). |
| @@ -580,7 +677,7 @@ void SSLClientSocketNSS::BufferSendComplete(int result) { |
| EnterFunction(result); |
| memio_PutWriteResult(nss_bufs_, result); |
| transport_send_busy_ = false; |
| - OnIOComplete(result); |
| + OnSendComplete(result); |
| LeaveFunction(""); |
| } |
| @@ -621,29 +718,28 @@ void SSLClientSocketNSS::BufferRecvComplete(int result) { |
| recv_buffer_ = NULL; |
| memio_PutReadResult(nss_bufs_, result); |
| transport_recv_busy_ = false; |
| - OnIOComplete(result); |
| + OnRecvComplete(result); |
| LeaveFunction(""); |
| } |
| -int SSLClientSocketNSS::DoLoop(int last_io_result) { |
| +int SSLClientSocketNSS::DoHandshakeLoop(int last_io_result) { |
| EnterFunction(last_io_result); |
| bool network_moved; |
| int rv = last_io_result; |
| do { |
| - network_moved = false; |
| // Default to STATE_NONE for next state. |
| // (This is a quirk carried over from the windows |
| // implementation. It makes reading the logs a bit harder.) |
| // State handlers can and often do call GotoState just |
| // to stay in the current state. |
| - State state = next_state_; |
| + State state = next_handshake_state_; |
| GotoState(STATE_NONE); |
| switch (state) { |
| case STATE_NONE: |
| // we're just pumping data between the buffer and the network |
| break; |
| - case STATE_HANDSHAKE_READ: |
| - rv = DoHandshakeRead(); |
| + case STATE_HANDSHAKE: |
| + rv = DoHandshake(); |
| break; |
| case STATE_VERIFY_CERT: |
| DCHECK(rv == OK); |
| @@ -652,12 +748,6 @@ int SSLClientSocketNSS::DoLoop(int last_io_result) { |
| case STATE_VERIFY_CERT_COMPLETE: |
| rv = DoVerifyCertComplete(rv); |
| break; |
| - case STATE_PAYLOAD_READ: |
| - rv = DoPayloadRead(); |
| - break; |
| - case STATE_PAYLOAD_WRITE: |
| - rv = DoPayloadWrite(); |
| - break; |
| default: |
| rv = ERR_UNEXPECTED; |
| NOTREACHED() << "unexpected state"; |
| @@ -665,13 +755,53 @@ int SSLClientSocketNSS::DoLoop(int last_io_result) { |
| } |
| // Do the actual network I/O |
| - if (nss_bufs_ != NULL) { |
| - int nsent = BufferSend(); |
| - int nreceived = BufferRecv(); |
| - network_moved = (nsent > 0 || nreceived >= 0); |
| - } |
| + network_moved = DoTransportIO(); |
| } while ((rv != ERR_IO_PENDING || network_moved) && |
| - next_state_ != STATE_NONE); |
| + next_handshake_state_ != STATE_NONE); |
| + LeaveFunction(""); |
| + return rv; |
| +} |
| + |
| +int SSLClientSocketNSS::DoReadLoop(int result) { |
| + EnterFunction(""); |
| + DCHECK(completed_handshake_); |
| + DCHECK(next_handshake_state_ == STATE_NONE); |
| + |
| + if (result < 0) |
| + return result; |
| + |
| + if (!nss_bufs_) |
| + return ERR_UNEXPECTED; |
| + |
| + bool network_moved; |
| + int rv; |
| + do { |
| + rv = DoPayloadRead(); |
| + network_moved = DoTransportIO(); |
| + } while (rv == ERR_IO_PENDING && network_moved); |
| + |
| + LeaveFunction(""); |
| + return rv; |
| +} |
| + |
| +int SSLClientSocketNSS::DoWriteLoop(int result) { |
| + EnterFunction(""); |
| + DCHECK(completed_handshake_); |
| + DCHECK(next_handshake_state_ == STATE_NONE); |
| + |
| + if (result < 0) |
| + return result; |
| + |
| + if (!nss_bufs_) |
| + return ERR_UNEXPECTED; |
| + |
| + bool network_moved; |
| + int rv; |
| + do { |
| + rv = DoPayloadWrite(); |
| + network_moved = DoTransportIO(); |
| + } while (rv == ERR_IO_PENDING && network_moved); |
| + |
| LeaveFunction(""); |
| return rv; |
| } |
| @@ -701,7 +831,7 @@ void SSLClientSocketNSS::HandshakeCallback(PRFileDesc* socket, |
| that->UpdateServerCert(); |
| } |
| -int SSLClientSocketNSS::DoHandshakeRead() { |
| +int SSLClientSocketNSS::DoHandshake() { |
| EnterFunction(""); |
| int net_error = net::OK; |
| int rv = SSL_ForceHandshake(nss_fd_); |
| @@ -723,7 +853,7 @@ int SSLClientSocketNSS::DoHandshakeRead() { |
| // If not done, stay in this state |
| if (net_error == ERR_IO_PENDING) { |
| - GotoState(STATE_HANDSHAKE_READ); |
| + GotoState(STATE_HANDSHAKE); |
| } else { |
| LOG(ERROR) << "handshake failed; NSS error code " << prerr |
| << ", net_error " << net_error; |
| @@ -744,7 +874,8 @@ int SSLClientSocketNSS::DoVerifyCert(int result) { |
| flags |= X509Certificate::VERIFY_EV_CERT; |
| verifier_.reset(new CertVerifier); |
| return verifier_->Verify(server_cert_, hostname_, flags, |
| - &server_cert_verify_result_, &io_callback_); |
| + &server_cert_verify_result_, |
| + &handshake_io_callback_); |
| } |
| // Derived from AuthCertificateCallback() in |
| @@ -805,46 +936,43 @@ int SSLClientSocketNSS::DoVerifyCertComplete(int result) { |
| // TODO(ukai): we may not need this call because it is now harmless to have an |
| // session with a bad cert. |
| InvalidateSessionIfBadCertificate(); |
| - // Exit DoLoop and return the result to the caller to Connect. |
| - DCHECK(next_state_ == STATE_NONE); |
| + // Exit DoHandshakeLoop and return the result to the caller to Connect. |
| + DCHECK(next_handshake_state_ == STATE_NONE); |
| return result; |
| } |
| int SSLClientSocketNSS::DoPayloadRead() { |
| - EnterFunction(user_buf_len_); |
| - int rv = PR_Read(nss_fd_, user_buf_->data(), user_buf_len_); |
| + EnterFunction(user_read_buf_len_); |
| + DCHECK(user_read_buf_); |
| + DCHECK(user_read_buf_len_ > 0); |
| + int rv = PR_Read(nss_fd_, user_read_buf_->data(), user_read_buf_len_); |
| if (rv >= 0) { |
| - LogData(user_buf_->data(), rv); |
| - user_buf_ = NULL; |
| + LogData(user_read_buf_->data(), rv); |
| LeaveFunction(""); |
| return rv; |
| } |
| PRErrorCode prerr = PR_GetError(); |
| if (prerr == PR_WOULD_BLOCK_ERROR) { |
| - GotoState(STATE_PAYLOAD_READ); |
| LeaveFunction(""); |
| return ERR_IO_PENDING; |
| } |
| - user_buf_ = NULL; |
| LeaveFunction(""); |
| return NetErrorFromNSPRError(prerr); |
| } |
| int SSLClientSocketNSS::DoPayloadWrite() { |
| - EnterFunction(user_buf_len_); |
| - int rv = PR_Write(nss_fd_, user_buf_->data(), user_buf_len_); |
| + EnterFunction(user_write_buf_len_); |
| + DCHECK(user_write_buf_); |
| + int rv = PR_Write(nss_fd_, user_write_buf_->data(), user_write_buf_len_); |
| if (rv >= 0) { |
| - LogData(user_buf_->data(), rv); |
| - user_buf_ = NULL; |
| + LogData(user_write_buf_->data(), rv); |
| LeaveFunction(""); |
| return rv; |
| } |
| PRErrorCode prerr = PR_GetError(); |
| if (prerr == PR_WOULD_BLOCK_ERROR) { |
| - GotoState(STATE_PAYLOAD_WRITE); |
| return ERR_IO_PENDING; |
| } |
| - user_buf_ = NULL; |
| LeaveFunction(""); |
| return NetErrorFromNSPRError(prerr); |
| } |