Index: net/socket/ssl_client_socket_unittest.cc |
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc |
index 05844cd0ef327d6c01ccbc847826cd1e3e0a7b76..8ba24b05aedfd3c99877528e2308247d6b97eee6 100644 |
--- a/net/socket/ssl_client_socket_unittest.cc |
+++ b/net/socket/ssl_client_socket_unittest.cc |
@@ -22,6 +22,7 @@ |
#include "net/socket/client_socket_handle.h" |
#include "net/socket/socket_test_util.h" |
#include "net/socket/tcp_client_socket.h" |
+#include "net/ssl/default_server_bound_cert_store.h" |
#include "net/ssl/ssl_cert_request_info.h" |
#include "net/ssl/ssl_config_service.h" |
#include "net/test/cert_test_util.h" |
@@ -557,6 +558,35 @@ class DeleteSocketCallback : public TestCompletionCallbackBase { |
DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); |
}; |
+// A ServerBoundCertStore that always returns an error when asked for a |
+// certificate. |
+class FailingServerBoundCertStore : public ServerBoundCertStore { |
+ virtual int GetServerBoundCert(const std::string& server_identifier, |
+ base::Time* expiration_time, |
+ std::string* private_key_result, |
+ std::string* cert_result, |
+ const GetCertCallback& callback) OVERRIDE { |
+ return ERR_UNEXPECTED; |
+ } |
+ virtual void SetServerBoundCert(const std::string& server_identifier, |
+ base::Time creation_time, |
+ base::Time expiration_time, |
+ const std::string& private_key, |
+ const std::string& cert) OVERRIDE {} |
+ virtual void DeleteServerBoundCert(const std::string& server_identifier, |
+ const base::Closure& completion_callback) |
+ OVERRIDE {} |
+ virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
+ base::Time delete_end, |
+ const base::Closure& completion_callback) |
+ OVERRIDE {} |
+ virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} |
+ virtual void GetAllServerBoundCerts(const GetCertListCallback& callback) |
+ OVERRIDE {} |
+ virtual int GetCertCount() OVERRIDE { return 0; } |
+ virtual void SetForceKeepSessionState() OVERRIDE {} |
+}; |
+ |
class SSLClientSocketTest : public PlatformTest { |
public: |
SSLClientSocketTest() |
@@ -713,6 +743,24 @@ class SSLClientSocketFalseStartTest : public SSLClientSocketTest { |
} |
}; |
+class SSLClientSocketChannelIDTest : public SSLClientSocketTest { |
wtc
2014/05/06 17:56:24
Nit: I think it is a good idea to copy the Connect
haavardm
2014/05/07 13:53:43
I did that to follow the code pattern of the rest
|
+ protected: |
+ void EnabledChannelID() { |
Ryan Sleevi
2014/05/06 18:45:26
these should be "EnableChannelID()" / "EnableFaili
haavardm
2014/05/07 13:53:43
Done.
|
+ cert_service_.reset( |
+ new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), |
+ base::MessageLoopProxy::current())); |
+ context_.server_bound_cert_service = cert_service_.get(); |
+ } |
+ |
+ void EnabledFailingChannelID() { |
+ cert_service_.reset(new ServerBoundCertService( |
+ new FailingServerBoundCertStore(), base::MessageLoopProxy::current())); |
+ context_.server_bound_cert_service = cert_service_.get(); |
+ } |
+ |
+ scoped_ptr<ServerBoundCertService> cert_service_; |
wtc
2014/05/06 17:56:24
The cert_service_ member should be private.
|
+}; |
+ |
//----------------------------------------------------------------------------- |
// LogContainsSSLConnectEndEvent returns true if the given index in the given |
@@ -2363,4 +2411,78 @@ TEST_F(SSLClientSocketFalseStartTest, NoForwardSecrecy) { |
TestFalseStart(server_options, client_config, false); |
} |
+// Connect to a server using channel id. It should allow the connection. |
+TEST_F(SSLClientSocketChannelIDTest, SendChannelID) { |
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, |
+ SpawnedTestServer::kLocalhost, |
+ base::FilePath()); |
+ |
+ SpawnedTestServer::SSLOptions ssl_options; |
wtc
2014/05/06 17:56:24
Remove the ssl_options variable in both unit tests
haavardm
2014/05/07 13:53:43
Done.
|
+ ASSERT_TRUE(test_server.Start()); |
+ |
+ AddressList addr; |
+ ASSERT_TRUE(test_server.GetAddressList(&addr)); |
+ |
+ TestCompletionCallback callback; |
+ scoped_ptr<StreamSocket> transport( |
+ new TCPClientSocket(addr, NULL, NetLog::Source())); |
wtc
2014/05/06 17:56:24
Just wanted to confirm that you intended to pass N
haavardm
2014/05/07 13:53:43
Yes, that was as intended. Some tests only pass if
|
+ int rv = transport->Connect(callback.callback()); |
+ if (rv == ERR_IO_PENDING) |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(OK, rv); |
+ |
+ EnabledChannelID(); |
+ SSLConfig ssl_config = kDefaultSSLConfig; |
+ ssl_config.channel_id_enabled = true; |
+ |
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( |
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); |
wtc
2014/05/06 17:56:24
1. BUG: the variable ssl_config is set but not use
haavardm
2014/05/07 13:53:43
Hm, the evils of copy and paste.. Yes the tests pa
|
+ |
+ rv = sock->Connect(callback.callback()); |
wtc
2014/05/06 17:56:24
Just wanted to point out that you omitted them che
|
+ if (rv == ERR_IO_PENDING) |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(OK, rv); |
wtc
2014/05/06 17:56:24
Nit: fix indentation.
haavardm
2014/05/07 13:53:43
Done.
|
+ EXPECT_TRUE(sock->IsConnected()); |
+ EXPECT_TRUE(sock->WasChannelIDSent()); |
+ |
+ sock->Disconnect(); |
+ EXPECT_FALSE(sock->IsConnected()); |
+} |
+ |
+// Connect to a server using channel id but without sending a key. It should |
+// fail. |
+TEST_F(SSLClientSocketChannelIDTest, FailingChannelID) { |
+ SpawnedTestServer test_server(SpawnedTestServer::TYPE_HTTPS, |
+ SpawnedTestServer::kLocalhost, |
+ base::FilePath()); |
+ |
+ SpawnedTestServer::SSLOptions ssl_options; |
+ ASSERT_TRUE(test_server.Start()); |
+ |
+ AddressList addr; |
+ ASSERT_TRUE(test_server.GetAddressList(&addr)); |
+ |
+ TestCompletionCallback callback; |
+ scoped_ptr<StreamSocket> transport( |
+ new TCPClientSocket(addr, NULL, NetLog::Source())); |
+ int rv = transport->Connect(callback.callback()); |
+ if (rv == ERR_IO_PENDING) |
+ rv = callback.WaitForResult(); |
+ EXPECT_EQ(OK, rv); |
+ |
+ EnabledFailingChannelID(); |
+ SSLConfig ssl_config = kDefaultSSLConfig; |
+ ssl_config.channel_id_enabled = true; |
+ |
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( |
+ transport.Pass(), test_server.host_port_pair(), kDefaultSSLConfig)); |
+ |
+ rv = sock->Connect(callback.callback()); |
+ if (rv == ERR_IO_PENDING) |
+ rv = callback.WaitForResult(); |
+ |
+ EXPECT_EQ(ERR_UNEXPECTED, rv); |
+ EXPECT_FALSE(sock->IsConnected()); |
+} |
+ |
} // namespace net |