Chromium Code Reviews| Index: net/http/http_network_transaction_unittest.cc |
| =================================================================== |
| --- net/http/http_network_transaction_unittest.cc (revision 12708) |
| +++ net/http/http_network_transaction_unittest.cc (working copy) |
| @@ -6,6 +6,9 @@ |
| #include "base/compiler_specific.h" |
| #include "net/base/client_socket_factory.h" |
| +#include "net/base/completion_callback.h" |
| +#include "net/base/ssl_client_socket.h" |
| +#include "net/base/ssl_info.h" |
| #include "net/base/test_completion_callback.h" |
| #include "net/base/upload_data.h" |
| #include "net/http/http_auth_handler_ntlm.h" |
| @@ -25,7 +28,8 @@ |
| struct MockConnect { |
| // Asynchronous connection success. |
| - MockConnect() : async(true), result(net::OK) { } |
| + MockConnect() : async(true), result(OK) { } |
| + MockConnect(bool a, int r) : async(a), result(r) { } |
|
wtc
2009/03/30 18:18:57
You may have intentionally named the constructor's
|
| bool async; |
| int result; |
| @@ -62,6 +66,7 @@ |
| struct MockSocket { |
| MockSocket() : reads(NULL), writes(NULL) { } |
| + MockSocket(MockRead* r, MockWrite* w) : reads(r), writes(w) { } |
|
wtc
2009/03/30 18:18:57
Not sure if we need this constructor, but it's ok.
|
| MockConnect connect; |
| MockRead* reads; |
| @@ -76,37 +81,35 @@ |
| // |
| MockSocket* mock_sockets[10]; |
| +// MockSSLSockets only need to keep track of the return code from calls to |
| +// Connect(). |
| +struct MockSSLSocket { |
| + MockSSLSocket(bool async, int result) : connect(async, result) { } |
| + |
| + MockConnect connect; |
| +}; |
| +MockSSLSocket* mock_ssl_sockets[10]; |
| + |
| // Index of the next mock_sockets element to use. |
| int mock_sockets_index; |
| +int mock_ssl_sockets_index; |
| -class MockTCPClientSocket : public net::ClientSocket { |
| +class MockClientSocket : public SSLClientSocket { |
| public: |
| - explicit MockTCPClientSocket(const net::AddressList& addresses) |
| - : data_(mock_sockets[mock_sockets_index++]), |
| - ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), |
| + explicit MockClientSocket() |
| + : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), |
| callback_(NULL), |
| - read_index_(0), |
| - read_offset_(0), |
| - write_index_(0), |
| connected_(false) { |
| - DCHECK(data_) << "overran mock_sockets array"; |
| } |
| + |
| // ClientSocket methods: |
| - virtual int Connect(net::CompletionCallback* callback) { |
| - DCHECK(!callback_); |
| - if (connected_) |
| - return net::OK; |
| - connected_ = true; |
| - if (data_->connect.async) { |
| - RunCallbackAsync(callback, data_->connect.result); |
| - return net::ERR_IO_PENDING; |
| - } |
| - return data_->connect.result; |
| - } |
| - virtual int ReconnectIgnoringLastError(net::CompletionCallback* callback) { |
| + virtual int Connect(CompletionCallback* callback) = 0; |
| + |
| + // SSLClientSocket methods: |
| + virtual void GetSSLInfo(SSLInfo* ssl_info) { |
| NOTREACHED(); |
| - return net::ERR_FAILED; |
| } |
| + |
| virtual void Disconnect() { |
| connected_ = false; |
| callback_ = NULL; |
| @@ -118,8 +121,65 @@ |
| return connected_; |
| } |
| // Socket methods: |
| - virtual int Read(char* buf, int buf_len, net::CompletionCallback* callback) { |
| + virtual int Read(char* buf, int buf_len, |
| + CompletionCallback* callback) = 0; |
| + virtual int Write(const char* buf, int buf_len, |
| + CompletionCallback* callback) = 0; |
| + |
| +#if defined(OS_LINUX) |
| + virtual int GetPeerName(struct sockaddr *name, socklen_t *namelen) { |
| + memset(reinterpret_cast<char *>(name), 0, *namelen); |
| + return OK; |
| + } |
| +#endif |
| + |
| + |
| + protected: |
| + void RunCallbackAsync(CompletionCallback* callback, int result) { |
| + callback_ = callback; |
| + MessageLoop::current()->PostTask(FROM_HERE, |
| + method_factory_.NewRunnableMethod( |
| + &MockClientSocket::RunCallback, result)); |
| + } |
| + |
| + void RunCallback(int result) { |
| + CompletionCallback* c = callback_; |
| + callback_ = NULL; |
| + if (c) |
| + c->Run(result); |
| + } |
| + |
| + ScopedRunnableMethodFactory<MockClientSocket> method_factory_; |
| + CompletionCallback* callback_; |
| + bool connected_; |
| +}; |
| + |
| +class MockTCPClientSocket : public MockClientSocket { |
| + public: |
| + explicit MockTCPClientSocket(const AddressList& addresses) |
| + : data_(mock_sockets[mock_sockets_index++]), |
| + read_index_(0), |
| + read_offset_(0), |
| + write_index_(0) { |
| + DCHECK(data_) << "overran mock_sockets array"; |
| + } |
| + |
| + // ClientSocket methods: |
| + virtual int Connect(CompletionCallback* callback) { |
| DCHECK(!callback_); |
| + if (connected_) |
| + return OK; |
| + connected_ = true; |
| + if (data_->connect.async) { |
| + RunCallbackAsync(callback, data_->connect.result); |
| + return ERR_IO_PENDING; |
| + } |
| + return data_->connect.result; |
| + } |
| + |
| + // Socket methods: |
| + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { |
| + DCHECK(!callback_); |
| MockRead& r = data_->reads[read_index_]; |
| int result = r.result; |
| if (r.data) { |
| @@ -137,12 +197,13 @@ |
| } |
| if (r.async) { |
| RunCallbackAsync(callback, result); |
| - return net::ERR_IO_PENDING; |
| + return ERR_IO_PENDING; |
| } |
| return result; |
| } |
| + |
| virtual int Write(const char* buf, int buf_len, |
| - net::CompletionCallback* callback) { |
| + CompletionCallback* callback) { |
| DCHECK(buf); |
| DCHECK(buf_len > 0); |
| DCHECK(!callback_); |
| @@ -159,49 +220,123 @@ |
| std::string actual_data(buf, buf_len); |
| EXPECT_EQ(expected_data, actual_data); |
| if (expected_data != actual_data) |
| - return net::ERR_UNEXPECTED; |
| - if (result == net::OK) |
| + return ERR_UNEXPECTED; |
| + if (result == OK) |
| result = w.data_len; |
| } |
| if (w.async) { |
| RunCallbackAsync(callback, result); |
| - return net::ERR_IO_PENDING; |
| + return ERR_IO_PENDING; |
| } |
| return result; |
| } |
| + |
| private: |
| - void RunCallbackAsync(net::CompletionCallback* callback, int result) { |
| - callback_ = callback; |
| - MessageLoop::current()->PostTask(FROM_HERE, |
| - method_factory_.NewRunnableMethod( |
| - &MockTCPClientSocket::RunCallback, result)); |
| - } |
| - void RunCallback(int result) { |
| - net::CompletionCallback* c = callback_; |
| - callback_ = NULL; |
| - if (c) |
| - c->Run(result); |
| - } |
| MockSocket* data_; |
| - ScopedRunnableMethodFactory<MockTCPClientSocket> method_factory_; |
| - net::CompletionCallback* callback_; |
| int read_index_; |
| int read_offset_; |
| int write_index_; |
| - bool connected_; |
| }; |
| -class MockClientSocketFactory : public net::ClientSocketFactory { |
| +class MockSSLClientSocket : public MockClientSocket { |
| public: |
| - virtual net::ClientSocket* CreateTCPClientSocket( |
| - const net::AddressList& addresses) { |
| + explicit MockSSLClientSocket( |
| + ClientSocket* transport_socket, |
| + const std::string& hostname, |
| + const SSLConfig& ssl_config) |
| + : transport_(transport_socket), |
| + data_(mock_ssl_sockets[mock_ssl_sockets_index++]) { |
| + DCHECK(data_) << "overran mock_ssl_sockets array"; |
| + } |
| + |
| + ~MockSSLClientSocket() { |
| + Disconnect(); |
| + } |
| + |
| + virtual void GetSSLInfo(SSLInfo* ssl_info) { |
| + ssl_info->Reset(); |
| + } |
| + |
| + friend class ConnectCallback; |
| + class ConnectCallback : |
| + public CompletionCallbackImpl<ConnectCallback> { |
| + public: |
| + ConnectCallback(MockSSLClientSocket *ssl_client_socket, |
| + CompletionCallback* user_callback, |
| + int rv) |
| + : ALLOW_THIS_IN_INITIALIZER_LIST( |
| + CompletionCallbackImpl<ConnectCallback>( |
| + this, &ConnectCallback::Wrapper)), |
| + ssl_client_socket_(ssl_client_socket), |
| + user_callback_(user_callback), |
| + rv_(rv) { |
| + } |
| + |
| + private: |
| + void Wrapper(int rv) { |
| + if (rv_ == OK) |
| + ssl_client_socket_->connected_ = true; |
| + user_callback_->Run(rv_); |
| + delete this; |
| + } |
| + |
| + MockSSLClientSocket* ssl_client_socket_; |
| + CompletionCallback* user_callback_; |
| + int rv_; |
| + }; |
| + |
| + virtual int Connect(CompletionCallback* callback) { |
| + DCHECK(!callback_); |
| + ConnectCallback* connect_callback = new ConnectCallback( |
| + this, callback, data_->connect.result); |
| + int rv = transport_->Connect(connect_callback); |
| + if (rv == OK) { |
| + delete connect_callback; |
| + if (data_->connect.async) { |
| + RunCallbackAsync(callback, data_->connect.result); |
| + return ERR_IO_PENDING; |
| + } |
| + if (data_->connect.result == OK) |
| + connected_ = true; |
| + return data_->connect.result; |
| + } |
| + return rv; |
| + } |
| + |
| + virtual void Disconnect() { |
| + MockClientSocket::Disconnect(); |
| + if (transport_ != NULL) |
| + transport_->Disconnect(); |
| + } |
| + |
| + // Socket methods: |
| + virtual int Read(char* buf, int buf_len, CompletionCallback* callback) { |
| + DCHECK(!callback_); |
| + return transport_->Read(buf, buf_len, callback); |
| + } |
| + |
| + virtual int Write(const char* buf, int buf_len, |
| + CompletionCallback* callback) { |
| + DCHECK(!callback_); |
| + return transport_->Write(buf, buf_len, callback); |
| + } |
| + |
| + private: |
| + scoped_ptr<ClientSocket> transport_; |
| + MockSSLSocket* data_; |
| +}; |
| + |
| +class MockClientSocketFactory : public ClientSocketFactory { |
| + public: |
| + virtual ClientSocket* CreateTCPClientSocket( |
| + const AddressList& addresses) { |
| return new MockTCPClientSocket(addresses); |
| } |
| - virtual net::SSLClientSocket* CreateSSLClientSocket( |
| - net::ClientSocket* transport_socket, |
| + virtual SSLClientSocket* CreateSSLClientSocket( |
| + ClientSocket* transport_socket, |
| const std::string& hostname, |
| - const net::SSLConfig& ssl_config) { |
| - return NULL; |
| + const SSLConfig& ssl_config) { |
| + return new MockSSLClientSocket(transport_socket, hostname, ssl_config); |
| } |
| }; |
| @@ -229,6 +364,7 @@ |
| PlatformTest::SetUp(); |
| mock_sockets[0] = NULL; |
| mock_sockets_index = 0; |
| + mock_ssl_sockets_index = 0; |
| } |
| virtual void TearDown() { |
| @@ -2711,4 +2847,139 @@ |
| EXPECT_FALSE(trans->response_.vary_data.is_valid()); |
| } |
| +// Test HTTPS connections to a site with a bad certificate |
| +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificate) { |
| + scoped_ptr<ProxyService> proxy_service(CreateNullProxyService()); |
| + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( |
| + CreateSession(proxy_service.get()), &mock_socket_factory)); |
| + |
| + HttpRequestInfo request; |
| + request.method = "GET"; |
| + request.url = GURL("https://www.google.com/"); |
| + request.load_flags = 0; |
| + |
| + MockWrite data_writes[] = { |
| + MockWrite("GET / HTTP/1.1\r\n" |
| + "Host: www.google.com\r\n" |
| + "Connection: keep-alive\r\n\r\n"), |
| + }; |
| + |
| + MockRead data_reads[] = { |
| + MockRead("HTTP/1.0 200 OK\r\n"), |
| + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), |
| + MockRead("Content-Length: 100\r\n\r\n"), |
| + MockRead(false, OK), |
| + }; |
| + |
| + MockSocket ssl_bad_certificate; |
| + MockSocket data(data_reads, data_writes); |
| + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); |
| + MockSSLSocket ssl(true, OK); |
| + |
| + mock_sockets[0] = &ssl_bad_certificate; |
| + mock_sockets[1] = &data; |
| + mock_sockets[2] = NULL; |
| + |
| + mock_ssl_sockets[0] = &ssl_bad; |
| + mock_ssl_sockets[1] = &ssl; |
| + mock_ssl_sockets[2] = NULL; |
| + |
| + TestCompletionCallback callback; |
| + |
| + int rv = trans->Start(&request, &callback); |
| + EXPECT_EQ(ERR_IO_PENDING, rv); |
| + |
| + rv = callback.WaitForResult(); |
| + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); |
| + |
| + rv = trans->RestartIgnoringLastError(&callback); |
| + EXPECT_EQ(ERR_IO_PENDING, rv); |
| + |
| + rv = callback.WaitForResult(); |
| + EXPECT_EQ(OK, rv); |
| + |
| + const HttpResponseInfo* response = trans->GetResponseInfo(); |
| + |
| + EXPECT_FALSE(response == NULL); |
| + EXPECT_EQ(100, response->headers->GetContentLength()); |
| +} |
| + |
| +// Test HTTPS connections to a site with a bad certificate, going through a |
| +// proxy |
| +TEST_F(HttpNetworkTransactionTest, HTTPSBadCertificateViaProxy) { |
| + scoped_ptr<ProxyService> proxy_service( |
| + CreateFixedProxyService("myproxy:70")); |
| + |
| + HttpRequestInfo request; |
| + request.method = "GET"; |
| + request.url = GURL("https://www.google.com/"); |
| + request.load_flags = 0; |
| + |
| + MockWrite proxy_writes[] = { |
| + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" |
| + "Host: www.google.com\r\n\r\n"), |
| + }; |
| + |
| + MockRead proxy_reads[] = { |
| + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), |
| + MockRead(false, net::OK) |
| + }; |
| + |
| + MockWrite data_writes[] = { |
| + MockWrite("CONNECT www.google.com:443 HTTP/1.1\r\n" |
| + "Host: www.google.com\r\n\r\n"), |
| + MockWrite("GET / HTTP/1.1\r\n" |
| + "Host: www.google.com\r\n" |
| + "Connection: keep-alive\r\n\r\n"), |
| + }; |
| + |
| + MockRead data_reads[] = { |
| + MockRead("HTTP/1.0 200 Connected\r\n\r\n"), |
| + MockRead("HTTP/1.0 200 OK\r\n"), |
| + MockRead("Content-Type: text/html; charset=iso-8859-1\r\n"), |
| + MockRead("Content-Length: 100\r\n\r\n"), |
| + MockRead(false, OK), |
| + }; |
| + |
| + MockSocket ssl_bad_certificate(proxy_reads, proxy_writes); |
| + MockSocket data(data_reads, data_writes); |
| + MockSSLSocket ssl_bad(true, ERR_CERT_AUTHORITY_INVALID); |
| + MockSSLSocket ssl(true, OK); |
| + |
| + mock_sockets[0] = &ssl_bad_certificate; |
|
wtc
2009/03/30 18:18:57
It's a little confusing for a socket to be named "
|
| + mock_sockets[1] = &data; |
| + mock_sockets[2] = NULL; |
| + |
| + mock_ssl_sockets[0] = &ssl_bad; |
| + mock_ssl_sockets[1] = &ssl; |
| + mock_ssl_sockets[2] = NULL; |
| + |
| + TestCompletionCallback callback; |
| + |
| + for (int i = 0; i < 2; i++) { |
| + mock_sockets_index = 0; |
| + mock_ssl_sockets_index = 0; |
| + |
| + scoped_ptr<HttpTransaction> trans(new HttpNetworkTransaction( |
| + CreateSession(proxy_service.get()), &mock_socket_factory)); |
| + |
| + int rv = trans->Start(&request, &callback); |
| + EXPECT_EQ(ERR_IO_PENDING, rv); |
| + |
| + rv = callback.WaitForResult(); |
| + EXPECT_EQ(ERR_CERT_AUTHORITY_INVALID, rv); |
| + |
| + rv = trans->RestartIgnoringLastError(&callback); |
| + EXPECT_EQ(ERR_IO_PENDING, rv); |
| + |
| + rv = callback.WaitForResult(); |
| + EXPECT_EQ(OK, rv); |
| + |
| + const HttpResponseInfo* response = trans->GetResponseInfo(); |
| + |
| + EXPECT_FALSE(response == NULL); |
| + EXPECT_EQ(100, response->headers->GetContentLength()); |
| + } |
| +} |
| + |
| } // namespace net |