Index: net/http/http_network_transaction_unittest.cc |
=================================================================== |
--- net/http/http_network_transaction_unittest.cc (revision 12708) |
+++ net/http/http_network_transaction_unittest.cc (working copy) |
@@ -6,6 +6,9 @@ |
#include "base/compiler_specific.h" |
#include "net/base/client_socket_factory.h" |
+#include "net/base/completion_callback.h" |
+#include "net/base/ssl_client_socket.h" |
+#include "net/base/ssl_info.h" |
#include "net/base/test_completion_callback.h" |
#include "net/base/upload_data.h" |
#include "net/http/http_auth_handler_ntlm.h" |
@@ -25,7 +28,8 @@ |
struct MockConnect { |
// Asynchronous connection success. |
- MockConnect() : async(true), result(net::OK) { } |
+ MockConnect() : async(true), result(OK) { } |
+ MockConnect(bool a, int r) : async(a), result(r) { } |
wtc
2009/03/30 18:18:57
You may have intentionally named the constructor's
|
bool async; |
int result; |
@@ -62,6 +66,7 @@ |
struct MockSocket { |
MockSocket() : reads(NULL), writes(NULL) { } |
+ MockSocket(MockRead* r, MockWrite* w) : reads(r), writes(w) { } |
wtc
2009/03/30 18:18:57
Not sure if we need this constructor, but it's ok.
|
MockConnect connect; |
MockRead* reads; |
@@ -76,37 +81,35 @@ |
// |
MockSocket* mock_sockets[10]; |
+// MockSSLSockets only need to keep track of the return code from calls to |
+// Connect(). |
+struct MockSSLSocket { |
+ MockSSLSocket(bool async, int result) : connect(async, result) { } |
+ |
+ MockConnect connect; |
+}; |
+MockSSLSocket* mock_ssl_sockets[10]; |
+ |
// Index of the next mock_sockets element to use. |
int mock_sockets_index; |
+int mock_ssl_sockets_index; |
-class MockTCPClientSocket : public net::ClientSocket { |
+class MockClientSocket : public SSLClientSocket { |
public: |
- explicit MockTCPClientSocket(const net::AddressList& addresses) |
- : data_(mock_sockets[mock_sockets_index++]), |
- ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), |
+ explicit MockClientSocket() |
+ : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), |
callback_(NULL), |
- read_index_(0), |
- read_offset_(0), |
- write_index_(0), |
connected_(false) { |
- DCHECK(data_) << "overran mock_sockets array"; |
} |
+ |
// ClientSocket methods: |
- virtual int Connect(net::CompletionCallback* callback) { |
- DCHECK(!callback_); |
- if (connected_) |
- return net::OK; |
- connected_ = true; |
- if (data_->connect.async) { |
- RunCallbackAsync(callback, data_->connect.result); |
- return net::ERR_IO_PENDING; |
- } |
- return data_->connect.result; |
- } |
- virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { |
+ virtual int Connect(CompletionCallback* callback) = 0; |
+ |
+ // SSLClientSocket methods: |
+ virtual void GetSSLInfo(SSLInfo* ssl_info) { |
NOTREACHED(); |
- return net::ERR_FAILED; |
} |
+ |
virtual void Disconnect() { |
connected_ = false; |
callback_ = NULL; |
@@ -118,8 +121,65 @@ |
return connected_; |
} |
// Socket methods: |
- virtual int Read(char* buf, int buf_len, net::CompletionCallback* callback) { |
+ virtual int Read(char* buf, int buf_len, |
+ CompletionCallback* callback) = 0; |
+ virtual int Write(const char* buf, int buf_len, |
+ CompletionCallback* callback) = 0; |
+ |
+#if defined(OS_LINUX) |
+ virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen) { |
+ memset(reinterpret_cast<char *>(name), 0, *namelen); |
+ return OK; |
+ } |
+#endif |
+ |
+ |
+ protected: |
+ void RunCallbackAsync(CompletionCallback* callback, int result) { |
+ callback_ = callback; |
+ MessageLoop::current()->PostTask(FROM_HERE, |
+ method_factory_.NewRunnableMethod( |
+ &MockClientSocket::RunCallback, result)); |
+ } |
+ |
+ void RunCallback(int result) { |
+ CompletionCallback* c = callback_; |
+ callback_ = NULL; |
+ if (c) |
+ c->Run(result); |
+ } |
+ |
+ ScopedRunnableMethodFactory<MockClientSocket> method_factory_; |
+ CompletionCallback* callback_; |
+ bool connected_; |
+}; |
+ |
+class MockTCPClientSocket : public MockClientSocket { |
+ public: |
+ explicit MockTCPClientSocket(const AddressList& addresses) |
+ : data_(mock_sockets[mock_sockets_index++]), |
+ read_index_(0), |
+ read_offset_(0), |
+ write_index_(0) { |
+ DCHECK(data_) << "overran mock_sockets array"; |
+ } |
+ |
+ // ClientSocket methods: |
+ virtual int Connect(CompletionCallback* callback) { |
DCHECK(!callback_); |
+ if (connected_) |
+ return OK; |
+ connected_ = true; |
+ if (data_->connect.async) { |
+ RunCallbackAsync(callback, data_->connect.result); |
+ return ERR_IO_PENDING; |
+ } |
+ return data_->connect.result; |
+ } |
+ |
+ // Socket methods: |
+ virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { |
+ DCHECK(!callback_); |
MockRead& r = data_->reads[read_index_]; |
int result = r.result; |
if (r.data) { |
@@ -137,12 +197,13 @@ |
} |
if (r.async) { |
RunCallbackAsync(callback, result); |
- return net::ERR_IO_PENDING; |
+ return ERR_IO_PENDING; |
} |
return result; |
} |
+ |
virtual int Write(const char* buf, int buf_len, |
- net::CompletionCallback* callback) { |
+ CompletionCallback* callback) { |
DCHECK(buf); |
DCHECK(buf_len > 0); |
DCHECK(!callback_); |
@@ -159,49 +220,123 @@ |
std::string actual_data(buf, buf_len); |
EXPECT_EQ(expected_data, actual_data); |
if (expected_data != actual_data) |
- return net::ERR_UNEXPECTED; |
- if (result == net::OK) |
+ return ERR_UNEXPECTED; |
+ if (result == OK) |
result = w.data_len; |
} |
if (w.async) { |
RunCallbackAsync(callback, result); |
- return net::ERR_IO_PENDING; |
+ return ERR_IO_PENDING; |
} |
return result; |
} |
+ |
private: |
- void RunCallbackAsync(net::CompletionCallback* callback, int result) { |
- callback_ = callback; |
- MessageLoop::current()->PostTask(FROM_HERE, |
- method_factory_.NewRunnableMethod( |
- &MockTCPClientSocket::RunCallback, result)); |
- } |
- void RunCallback(int result) { |
- net::CompletionCallback* c = callback_; |
- callback_ = NULL; |
- if (c) |
- c->Run(result); |
- } |
MockSocket* data_; |
- ScopedRunnableMethodFactory<MockTCPClientSocket> method_factory_; |
- net::CompletionCallback* callback_; |
int read_index_; |
int read_offset_; |
int write_index_; |
- bool connected_; |
}; |
-class MockClientSocketFactory : public net::ClientSocketFactory { |
+class MockSSLClientSocket : public MockClientSocket { |
public: |
- virtual net::ClientSocket* CreateTCPClientSocket( |
- const net::AddressList& addresses) { |
+ explicit MockSSLClientSocket( |
+ ClientSocket* transport_socket, |
+ const std::string& hostname, |
+ const SSLConfig& ssl_config) |
+ : transport_(transport_socket), |
+ data_(mock_ssl_sockets[mock_ssl_sockets_index++]) { |
+ DCHECK(data_) << "overran mock_ssl_sockets array"; |
+ } |
+ |
+ ~MockSSLClientSocket() { |
+ Disconnect(); |
+ } |
+ |
+ virtual void GetSSLInfo(SSLInfo* ssl_info) { |
+ ssl_info->Reset(); |
+ } |
+ |
+ friend class ConnectCallback; |
+ class ConnectCallback : |
+ public CompletionCallbackImpl<ConnectCallback> { |
+ public: |
+ ConnectCallback(MockSSLClientSocket *ssl_client_socket, |
+ CompletionCallback* user_callback, |
+ int rv) |
+ : ALLOW_THIS_IN_INITIALIZER_LIST( |
+ CompletionCallbackImpl<ConnectCallback>( |
+ this, &ConnectCallback::Wrapper)), |
+ ssl_client_socket_(ssl_client_socket), |
+ user_callback_(user_callback), |
+ rv_(rv) { |
+ } |
+ |
+ private: |
+ void Wrapper(int rv) { |
+ if (rv_ == OK) |
+ ssl_client_socket_->connected_ = true; |
+ user_callback_->Run(rv_); |
+ delete this; |
+ } |
+ |
+ MockSSLClientSocket* ssl_client_socket_; |
+ CompletionCallback* user_callback_; |
+ int rv_; |
+ }; |
+ |
+ virtual int Connect(CompletionCallback* callback) { |
+ DCHECK(!callback_); |
+ ConnectCallback* connect_callback = new ConnectCallback( |
+ this, callback, data_->connect.result); |
+ int rv = transport_->Connect(connect_callback); |
+ if (rv == OK) { |
+ delete connect_callback; |
+ if (data_->connect.async) { |
+ RunCallbackAsync(callback, data_->connect.result); |
+ return ERR_IO_PENDING; |
+ } |
+ if (data_->connect.result == OK) |
+ connected_ = true; |
+ return data_->connect.result; |
+ } |
+ return rv; |
+ } |
+ |
+ virtual void Disconnect() { |
+ MockClientSocket::Disconnect(); |
+ if (transport_ != NULL) |
+ transport_->Disconnect(); |
+ } |
+ |
+ // Socket methods: |
+ virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { |
+ DCHECK(!callback_); |
+ return transport_->Read(buf, buf_len, callback); |
+ } |
+ |
+ virtual int Write(const char* buf, int buf_len, |
+ CompletionCallback* callback) { |
+ DCHECK(!callback_); |
+ return transport_->Write(buf, buf_len, callback); |
+ } |
+ |
+ private: |
+ scoped_ptr<ClientSocket> transport_; |
+ MockSSLSocket* data_; |
+}; |
+ |
+class MockClientSocketFactory : public ClientSocketFactory { |
+ public: |
+ virtual ClientSocket* CreateTCPClientSocket( |
+ const AddressList& addresses) { |
return new MockTCPClientSocket(addresses); |
} |
- virtual net::SSLClientSocket* CreateSSLClientSocket( |
- net::ClientSocket* transport_socket, |
+ virtual SSLClientSocket* CreateSSLClientSocket( |
+ ClientSocket* transport_socket, |
const std::string& hostname, |
- const net::SSLConfig& ssl_config) { |
- return NULL; |
+ const SSLConfig& ssl_config) { |
+ return new MockSSLClientSocket(transport_socket, hostname, ssl_config); |
} |
}; |
@@ -229,6 +364,7 @@ |
PlatformTest::SetUp(); |
mock_sockets[0] = NULL; |
mock_sockets_index = 0; |
+ mock_ssl_sockets_index = 0; |
} |
virtual void TearDown() { |
@@ -2711,4 +2847,139 @@ |
EXPECT_FALSE(trans->response_.vary_data.is_valid()); |
} |
+// Test HTTPS connections to a site with a bad certificate |
+TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificate) { |
+ scoped_ptr<ProxyService> proxy_service(CreateNullProxyService()); |
+ scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( |
+ CreateSession(proxy_service.get()), &mock_socket_factory)); |
+ |
+ HttpRequestInfo request; |
+ request.method = "GET"; |
+ request.url = GURL("https://www.google.com/"); |
+ request.load_flags = 0; |
+ |
+ MockWrite data_writes[] = { |
+ MockWrite("GET / HTTP/1.1\r\n" |
+ "Host: www.google.com\r\n" |
+ "Connection: keep-alive\r\n\r\n"), |
+ }; |
+ |
+ MockRead data_reads[] = { |
+ MockRead("HTTP/1.0 200 OK\r\n"), |
+ MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), |
+ MockRead("Content-Length: 100\r\n\r\n"), |
+ MockRead(false, OK), |
+ }; |
+ |
+ MockSocket ssl_bad_certificate; |
+ MockSocket data(data_reads, data_writes); |
+ MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); |
+ MockSSLSocket ssl(true, OK); |
+ |
+ mock_sockets[0] = &ssl_bad_certificate; |
+ mock_sockets[1] = &data; |
+ mock_sockets[2] = NULL; |
+ |
+ mock_ssl_sockets[0] = &ssl_bad; |
+ mock_ssl_sockets[1] = &ssl; |
+ mock_ssl_sockets[2] = NULL; |
+ |
+ TestCompletionCallback callback; |
+ |
+ int rv = trans->Start(&request, &callback); |
+ EXPECT_EQ(ERR_IO_PENDING, rv); |
+ |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); |
+ |
+ rv = trans->RestartIgnoringLastError(&callback); |
+ EXPECT_EQ(ERR_IO_PENDING, rv); |
+ |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(OK, rv); |
+ |
+ const HttpResponseInfo* response = trans->GetResponseInfo(); |
+ |
+ EXPECT_FALSE(response == NULL); |
+ EXPECT_EQ(100, response->headers->GetContentLength()); |
+} |
+ |
+// Test HTTPS connections to a site with a bad certificate, going through a |
+// proxy |
+TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificateViaProxy) { |
+ scoped_ptr<ProxyService> proxy_service( |
+ CreateFixedProxyService("myproxy:70")); |
+ |
+ HttpRequestInfo request; |
+ request.method = "GET"; |
+ request.url = GURL("https://www.google.com/"); |
+ request.load_flags = 0; |
+ |
+ MockWrite proxy_writes[] = { |
+ MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" |
+ "Host: www.google.com\r\n\r\n"), |
+ }; |
+ |
+ MockRead proxy_reads[] = { |
+ MockRead("HTTP/1.0 200 Connected\r\n\r\n"), |
+ MockRead(false, net::OK) |
+ }; |
+ |
+ MockWrite data_writes[] = { |
+ MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" |
+ "Host: www.google.com\r\n\r\n"), |
+ MockWrite("GET / HTTP/1.1\r\n" |
+ "Host: www.google.com\r\n" |
+ "Connection: keep-alive\r\n\r\n"), |
+ }; |
+ |
+ MockRead data_reads[] = { |
+ MockRead("HTTP/1.0 200 Connected\r\n\r\n"), |
+ MockRead("HTTP/1.0 200 OK\r\n"), |
+ MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), |
+ MockRead("Content-Length: 100\r\n\r\n"), |
+ MockRead(false, OK), |
+ }; |
+ |
+ MockSocket ssl_bad_certificate(proxy_reads, proxy_writes); |
+ MockSocket data(data_reads, data_writes); |
+ MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); |
+ MockSSLSocket ssl(true, OK); |
+ |
+ mock_sockets[0] = &ssl_bad_certificate; |
wtc
2009/03/30 18:18:57
It's a little confusing for a socket to be named "
|
+ mock_sockets[1] = &data; |
+ mock_sockets[2] = NULL; |
+ |
+ mock_ssl_sockets[0] = &ssl_bad; |
+ mock_ssl_sockets[1] = &ssl; |
+ mock_ssl_sockets[2] = NULL; |
+ |
+ TestCompletionCallback callback; |
+ |
+ for (int i = 0; i < 2; i++) { |
+ mock_sockets_index = 0; |
+ mock_ssl_sockets_index = 0; |
+ |
+ scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( |
+ CreateSession(proxy_service.get()), &mock_socket_factory)); |
+ |
+ int rv = trans->Start(&request, &callback); |
+ EXPECT_EQ(ERR_IO_PENDING, rv); |
+ |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); |
+ |
+ rv = trans->RestartIgnoringLastError(&callback); |
+ EXPECT_EQ(ERR_IO_PENDING, rv); |
+ |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(OK, rv); |
+ |
+ const HttpResponseInfo* response = trans->GetResponseInfo(); |
+ |
+ EXPECT_FALSE(response == NULL); |
+ EXPECT_EQ(100, response->headers->GetContentLength()); |
+ } |
+} |
+ |
} // namespace net |