Index: net/socket/socket_test_util.cc |
diff --git a/net/socket/socket_test_util.cc b/net/socket/socket_test_util.cc |
index 275a403a237c8c3da5bdda3d3ae8da5aa32e2e84..4fa31a5d55587d591907d628d057e08bd6e2cfe5 100644 |
--- a/net/socket/socket_test_util.cc |
+++ b/net/socket/socket_test_util.cc |
@@ -253,6 +253,7 @@ DelayedSocketData::DelayedSocketData( |
MockWrite* writes, size_t writes_count) |
: StaticSocketDataProvider(reads, reads_count, writes, writes_count), |
write_delay_(write_delay), |
+ read_in_progress_(false), |
ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
DCHECK_GE(write_delay_, 0); |
} |
@@ -262,6 +263,7 @@ DelayedSocketData::DelayedSocketData( |
size_t reads_count, MockWrite* writes, size_t writes_count) |
: StaticSocketDataProvider(reads, reads_count, writes, writes_count), |
write_delay_(write_delay), |
+ read_in_progress_(false), |
ALLOW_THIS_IN_INITIALIZER_LIST(weak_factory_(this)) { |
DCHECK_GE(write_delay_, 0); |
set_connect_data(connect); |
@@ -271,20 +273,23 @@ DelayedSocketData::~DelayedSocketData() { |
} |
void DelayedSocketData::ForceNextRead() { |
+ DCHECK(read_in_progress_); |
write_delay_ = 0; |
CompleteRead(); |
} |
MockRead DelayedSocketData::GetNextRead() { |
- if (write_delay_ > 0) |
- return MockRead(true, ERR_IO_PENDING); |
- return StaticSocketDataProvider::GetNextRead(); |
+ MockRead out = MockRead(true, ERR_IO_PENDING); |
+ if (write_delay_ <= 0) |
+ out = StaticSocketDataProvider::GetNextRead(); |
+ read_in_progress_ = (out.result == ERR_IO_PENDING); |
+ return out; |
} |
MockWriteResult DelayedSocketData::OnWrite(const std::string& data) { |
MockWriteResult rv = StaticSocketDataProvider::OnWrite(data); |
// Now that our write has completed, we can allow reads to continue. |
- if (!--write_delay_) |
+ if (!--write_delay_ && read_in_progress_) |
MessageLoop::current()->PostDelayedTask( |
FROM_HERE, |
base::Bind(&DelayedSocketData::CompleteRead, |
@@ -295,12 +300,13 @@ MockWriteResult DelayedSocketData::OnWrite(const std::string& data) { |
void DelayedSocketData::Reset() { |
set_socket(NULL); |
+ read_in_progress_ = false; |
weak_factory_.InvalidateWeakPtrs(); |
StaticSocketDataProvider::Reset(); |
} |
void DelayedSocketData::CompleteRead() { |
- if (socket()) |
+ if (socket() && read_in_progress_) |
socket()->OnReadComplete(GetNextRead()); |
} |
@@ -353,6 +359,7 @@ MockRead OrderedSocketData::GetNextRead() { |
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 |
<< ": Read " << read_index(); |
DumpMockRead(next_read); |
+ blocked_ = (next_read.result == ERR_IO_PENDING); |
return StaticSocketDataProvider::GetNextRead(); |
} |
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 |
@@ -394,7 +401,7 @@ void OrderedSocketData::Reset() { |
} |
void OrderedSocketData::CompleteRead() { |
- if (socket()) { |
+ if (socket() && blocked_) { |
NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_; |
socket()->OnReadComplete(GetNextRead()); |
} |
@@ -582,18 +589,6 @@ void MockClientSocketFactory::ResetNextMockIndexes() { |
mock_ssl_data_.ResetNextIndex(); |
} |
-MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket( |
- size_t index) const { |
- DCHECK_LT(index, tcp_client_sockets_.size()); |
- return tcp_client_sockets_[index]; |
-} |
- |
-MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( |
- size_t index) const { |
- DCHECK_LT(index, ssl_client_sockets_.size()); |
- return ssl_client_sockets_[index]; |
-} |
- |
DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( |
DatagramSocket::BindType bind_type, |
const RandIntCallback& rand_int_cb, |
@@ -602,7 +597,6 @@ DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( |
SocketDataProvider* data_provider = mock_data_.GetNext(); |
MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); |
data_provider->set_socket(socket); |
- udp_client_sockets_.push_back(socket); |
return socket; |
} |
@@ -614,7 +608,6 @@ StreamSocket* MockClientSocketFactory::CreateTransportClientSocket( |
MockTCPClientSocket* socket = |
new MockTCPClientSocket(addresses, net_log, data_provider); |
data_provider->set_socket(socket); |
- tcp_client_sockets_.push_back(socket); |
return socket; |
} |
@@ -627,7 +620,6 @@ SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( |
MockSSLClientSocket* socket = |
new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, |
ssl_host_info, mock_ssl_data_.GetNext()); |
- ssl_client_sockets_.push_back(socket); |
return socket; |
} |