Index: net/socket/socket_test_util.cc |
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc |
index 24447569f82035f08b86b6acb2d85d2fbe95d5de..16f50d66f16ab611fb49d8a7267306fa0cf842d6 100644 |
--- a/net/socket/socket_test_util.cc |
+++ b/net/socket/socket_test_util.cc |
@@ -7,8 +7,9 @@ |
#include <algorithm> |
#include <vector> |
- |
#include "base/basictypes.h" |
+#include "base/bind.h" |
+#include "base/bind_helpers.h" |
#include "base/compiler_specific.h" |
#include "base/message_loop.h" |
#include "base/time.h" |
@@ -630,7 +631,7 @@ void MockClientSocketFactory::ClearSSLSessionCache() { |
} |
MockClientSocket::MockClientSocket(net::NetLog* net_log) |
- : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), |
+ : ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)), |
connected_(false), |
net_log_(NetLog::Source(), net_log) { |
} |
@@ -702,15 +703,26 @@ MockClientSocket::~MockClientSocket() {} |
void MockClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, |
int result) { |
MessageLoop::current()->PostTask(FROM_HERE, |
- method_factory_.NewRunnableMethod( |
- &MockClientSocket::RunCallback, callback, result)); |
+ base::Bind(&MockClientSocket::RunOldCallback, weak_factory_.GetWeakPtr(), |
+ callback, result)); |
+} |
+void MockClientSocket::RunCallbackAsync(const net::CompletionCallback& callback, |
+ int result) { |
+ MessageLoop::current()->PostTask(FROM_HERE, |
+ base::Bind(&MockClientSocket::RunCallback, weak_factory_.GetWeakPtr(), |
+ callback, result)); |
} |
-void MockClientSocket::RunCallback(net::OldCompletionCallback* callback, |
- int result) { |
+void MockClientSocket::RunOldCallback(net::OldCompletionCallback* callback, |
+ int result) { |
if (callback) |
callback->Run(result); |
} |
+void MockClientSocket::RunCallback(const net::CompletionCallback& callback, |
+ int result) { |
+ if (!callback.is_null()) |
+ callback.Run(result); |
+} |
MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, |
net::NetLog* net_log, |
@@ -796,6 +808,19 @@ int MockTCPClientSocket::Connect(net::OldCompletionCallback* callback) { |
} |
return data_->connect_data().result; |
} |
+int MockTCPClientSocket::Connect(const net::CompletionCallback& callback) { |
+ if (connected_) |
+ return net::OK; |
+ |
+ connected_ = true; |
+ peer_closed_connection_ = false; |
+ if (data_->connect_data().async) { |
+ RunCallbackAsync(callback, data_->connect_data().result); |
+ return net::ERR_IO_PENDING; |
+ } |
+ |
+ return data_->connect_data().result; |
+} |
void MockTCPClientSocket::Disconnect() { |
MockClientSocket::Disconnect(); |
@@ -852,7 +877,7 @@ void MockTCPClientSocket::OnReadComplete(const MockRead& data) { |
net::OldCompletionCallback* callback = pending_callback_; |
int rv = CompleteRead(); |
- RunCallback(callback, rv); |
+ RunOldCallback(callback, rv); |
} |
int MockTCPClientSocket::CompleteRead() { |
@@ -1005,6 +1030,19 @@ int DeterministicMockTCPClientSocket::Connect( |
} |
return data_->connect_data().result; |
} |
+int DeterministicMockTCPClientSocket::Connect( |
+ const net::CompletionCallback& callback) { |
+ if (connected_) |
+ return net::OK; |
+ |
+ connected_ = true; |
+ if (data_->connect_data().async) { |
+ RunCallbackAsync(callback, data_->connect_data().result); |
+ return net::ERR_IO_PENDING; |
+ } |
+ |
+ return data_->connect_data().result; |
+} |
void DeterministicMockTCPClientSocket::Disconnect() { |
MockClientSocket::Disconnect(); |
@@ -1036,15 +1074,17 @@ base::TimeDelta DeterministicMockTCPClientSocket::GetConnectTimeMicros() const { |
void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} |
-class MockSSLClientSocket::ConnectCallback |
- : public net::OldCompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { |
+class MockSSLClientSocket::OldConnectCallback |
+ : public net::OldCompletionCallbackImpl< |
+ MockSSLClientSocket::OldConnectCallback> { |
public: |
- ConnectCallback(MockSSLClientSocket *ssl_client_socket, |
- net::OldCompletionCallback* user_callback, |
- int rv) |
+ OldConnectCallback(MockSSLClientSocket *ssl_client_socket, |
+ net::OldCompletionCallback* user_callback, |
+ int rv) |
: ALLOW_THIS_IN_INITIALIZER_LIST( |
- net::OldCompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( |
- this, &ConnectCallback::Wrapper)), |
+ net::OldCompletionCallbackImpl< |
+ MockSSLClientSocket::OldConnectCallback>( |
+ this, &OldConnectCallback::Wrapper)), |
ssl_client_socket_(ssl_client_socket), |
user_callback_(user_callback), |
rv_(rv) { |
@@ -1062,6 +1102,32 @@ class MockSSLClientSocket::ConnectCallback |
net::OldCompletionCallback* user_callback_; |
int rv_; |
}; |
+class MockSSLClientSocket::ConnectCallback { |
+ public: |
+ ConnectCallback(MockSSLClientSocket *ssl_client_socket, |
+ const CompletionCallback& user_callback, |
+ int rv) |
+ : ALLOW_THIS_IN_INITIALIZER_LIST(callback_( |
+ base::Bind(&ConnectCallback::Wrapper, base::Unretained(this)))), |
+ ssl_client_socket_(ssl_client_socket), |
+ user_callback_(user_callback), |
+ rv_(rv) { |
+ } |
+ |
+ const CompletionCallback& callback() const { return callback_; } |
+ |
+ private: |
+ void Wrapper(int rv) { |
+ if (rv_ == net::OK) |
+ ssl_client_socket_->connected_ = true; |
+ user_callback_.Run(rv_); |
+ } |
+ |
+ CompletionCallback callback_; |
+ MockSSLClientSocket* ssl_client_socket_; |
+ CompletionCallback user_callback_; |
+ int rv_; |
+}; |
MockSSLClientSocket::MockSSLClientSocket( |
net::ClientSocketHandle* transport_socket, |
@@ -1093,7 +1159,7 @@ int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, |
} |
int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { |
- ConnectCallback* connect_callback = new ConnectCallback( |
+ OldConnectCallback* connect_callback = new OldConnectCallback( |
this, callback, data_->connect.result); |
int rv = transport_->socket()->Connect(connect_callback); |
if (rv == net::OK) { |
@@ -1108,6 +1174,20 @@ int MockSSLClientSocket::Connect(net::OldCompletionCallback* callback) { |
} |
return rv; |
} |
+int MockSSLClientSocket::Connect(const net::CompletionCallback& callback) { |
+ ConnectCallback connect_callback(this, callback, data_->connect.result); |
+ int rv = transport_->socket()->Connect(connect_callback.callback()); |
+ if (rv == net::OK) { |
+ if (data_->connect.result == net::OK) |
+ connected_ = true; |
+ if (data_->connect.async) { |
+ RunCallbackAsync(callback, data_->connect.result); |
+ return net::ERR_IO_PENDING; |
+ } |
+ return data_->connect.result; |
+ } |
+ return rv; |
+} |
void MockSSLClientSocket::Disconnect() { |
MockClientSocket::Disconnect(); |
@@ -1185,7 +1265,7 @@ MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, |
pending_buf_len_(0), |
pending_callback_(NULL), |
net_log_(NetLog::Source(), net_log), |
- ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)) { |
+ ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
DCHECK(data_); |
data_->Reset(); |
} |
@@ -1328,8 +1408,8 @@ int MockUDPClientSocket::CompleteRead() { |
void MockUDPClientSocket::RunCallbackAsync(net::OldCompletionCallback* callback, |
int result) { |
MessageLoop::current()->PostTask(FROM_HERE, |
- method_factory_.NewRunnableMethod( |
- &MockUDPClientSocket::RunCallback, callback, result)); |
+ base::Bind(&MockUDPClientSocket::RunCallback, weak_factory_.GetWeakPtr(), |
+ callback, result)); |
} |
void MockUDPClientSocket::RunCallback(net::OldCompletionCallback* callback, |