Index: net/socket/ssl_client_socket_unittest.cc |
diff --git a/net/socket/ssl_client_socket_unittest.cc b/net/socket/ssl_client_socket_unittest.cc |
index c36e5819665c62a2c18f28a7c0d47755332e7248..bc56abcba59cfb84ff05d1792b0cbee2a2447fac 100644 |
--- a/net/socket/ssl_client_socket_unittest.cc |
+++ b/net/socket/ssl_client_socket_unittest.cc |
@@ -32,6 +32,10 @@ |
#include "testing/gtest/include/gtest/gtest.h" |
#include "testing/platform_test.h" |
+#if defined(OS_WIN) |
+#include "base/win/windows_version.h" |
+#endif |
+ |
//----------------------------------------------------------------------------- |
namespace net { |
@@ -566,6 +570,90 @@ class CountingStreamSocket : public WrappedStreamSocket { |
int write_count_; |
}; |
+// WriteCapturingSocket is a fake StreamSocket that captures all writes and |
+// fails any reads. It is intended to capture the ClientHello. |
+class WriteCapturingSocket : public StreamSocket { |
+ public: |
+ WriteCapturingSocket(net::NetLog* net_log) |
+ : net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_SOCKET)) {} |
+ |
+ const std::vector<uint8_t>& bytes_written() const { return bytes_written_; } |
+ |
+ // StreamSocket implementation: |
+ virtual int Connect(const CompletionCallback& callback) OVERRIDE { |
+ return OK; |
+ } |
+ virtual void Disconnect() OVERRIDE { } |
+ virtual bool IsConnected() const OVERRIDE { |
+ return true; |
+ } |
+ virtual bool IsConnectedAndIdle() const OVERRIDE { |
+ return true; |
+ } |
+ virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { |
+ // NSS requires this method be functional. |
+ IPAddressNumber number; |
+ CHECK(ParseIPLiteralToNumber("127.0.0.1", &number)); |
+ *address = IPEndPoint(number, 443); |
+ return OK; |
+ } |
+ virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return ERR_FAILED; |
+ } |
+ virtual const BoundNetLog& NetLog() const OVERRIDE { |
+ return net_log_; |
+ } |
+ virtual void SetSubresourceSpeculation() OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ } |
+ virtual void SetOmniboxSpeculation() OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ } |
+ virtual bool WasEverUsed() const OVERRIDE { |
+ NOTIMPLEMENTED(); |
+ return false; |
+ } |
+ virtual bool UsingTCPFastOpen() const OVERRIDE { |
+ return false; |
+ } |
+ virtual bool WasNpnNegotiated() const OVERRIDE { |
+ return false; |
+ } |
+ virtual NextProto GetNegotiatedProtocol() const OVERRIDE { |
+ return kProtoUnknown; |
+ } |
+ virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { |
+ return false; |
+ } |
+ |
+ // Socket implementation: |
+ virtual int Read(IOBuffer* buf, |
+ int buf_len, |
+ const CompletionCallback& callback) OVERRIDE { |
+ // Fail the read to stop the handshake at ClientHello. |
+ return ERR_FAILED; |
+ } |
+ virtual int Write(IOBuffer* buf, |
+ int buf_len, |
+ const CompletionCallback& callback) OVERRIDE { |
+ for (int i = 0; i < buf_len; i++) { |
+ bytes_written_.push_back(buf->data()[i]); |
+ } |
+ return buf_len; |
+ } |
+ virtual int SetReceiveBufferSize(int32 size) OVERRIDE { |
+ return 0; |
+ } |
+ virtual int SetSendBufferSize(int32 size) OVERRIDE { |
+ return 0; |
+ } |
+ |
+ private: |
+ BoundNetLog net_log_; |
+ std::vector<uint8_t> bytes_written_; |
+}; |
+ |
// CompletionCallback that will delete the associated StreamSocket when |
// the callback is invoked. |
class DeleteSocketCallback : public TestCompletionCallbackBase { |
@@ -2684,6 +2772,94 @@ TEST_F(SSLClientSocketTest, ReuseStates) { |
// attempt to read one byte extra. |
} |
+#if defined(OS_WIN) |
+ |
+bool IsECDSACipherSuite(uint16_t cipher_suite) { |
+ // RFC 4492. |
+ if (0xc001 <= cipher_suite && cipher_suite <= 0xc00a) |
+ return true; |
+ // RFC 5289. |
+ if (0xc023 <= cipher_suite && cipher_suite <= 0xc026) |
+ return true; |
+ if (0xc02b <= cipher_suite && cipher_suite <= 0xc02e) |
+ return true; |
+ return false; |
Ryan Sleevi
2014/09/02 23:45:01
SSLCipherSuiteToStrings(&key_exchange, ..., cipher
davidben
2014/09/03 17:10:41
Done.
|
+} |
+ |
+// Test that ECDSA is disabled on Windows XP, where ECDSA certificates cannot be |
+// verified. |
+TEST_F(SSLClientSocketTest, DisableECDSAOnXP) { |
+ if (base::win::GetVersion() >= base::win::VERSION_VISTA) { |
+ LOG(INFO) << "Skipping test on this version."; |
+ return; |
+ } |
+ |
+ scoped_ptr<WriteCapturingSocket> transport(new WriteCapturingSocket(&log_)); |
+ WriteCapturingSocket* raw_transport = transport.get(); |
+ |
+ // Handshake up to trying to read the ServerHello. |
+ scoped_ptr<SSLClientSocket> sock(CreateSSLClientSocket( |
+ transport.PassAs<StreamSocket>(), |
+ HostPortPair("example.com", 443), kDefaultSSLConfig)); |
+ TestCompletionCallback callback; |
+ EXPECT_EQ(ERR_FAILED, callback.GetResult(sock->Connect(callback.callback()))); |
+ base::RunLoop().RunUntilIdle(); |
+ |
+ // Parse out the cipher list. This will require that the ClientHello is not |
+ // fragmented before the cipher list because that would be an exceedingly long |
+ // cipher list. |
+ std::vector<uint8_t> client_hello = raw_transport->bytes_written(); |
+ |
+ // TLSPlaintext header: |
+ ASSERT_GE(client_hello.size(), 5u); |
+ EXPECT_EQ(22, client_hello[0]); // type |
+ // Next two bytes are the version. |
+ uint16_t record_length = (client_hello[3] << 8) | client_hello[4]; |
+ // Grab the record body. |
+ ASSERT_GE(client_hello.size(), 5u + record_length); |
+ std::vector<uint8_t> record_body(client_hello.begin() + 5, |
+ client_hello.begin() + 5 + record_length); |
+ |
+ // Handshake header: |
+ ASSERT_GE(record_body.size(), 4u); |
+ EXPECT_EQ(1, record_body[0]); // msg_type |
+ uint32_t length = |
+ (record_body[1] << 16) | (record_body[2] << 8) | record_body[3]; |
+ std::vector<uint8_t> message(record_body.begin() + 4, record_body.end()); |
+ // There cannot be a handshake message after ClientHello, though the |
+ // ClientHello could conceivably be fragmented across two records. |
+ ASSERT_LE(message.size(), length); |
+ |
+ // ClientHello: |
+ |
+ // Skip past the client_version and random. |
+ ASSERT_GE(message.size(), 2u + 32u); |
+ message.erase(message.begin(), message.begin() + 2 + 32); |
+ |
+ // Skip past the session id. |
+ ASSERT_GE(message.size(), 1u); |
+ uint8_t session_id_length = message[0]; |
+ ASSERT_GE(message.size(), 1u + session_id_length); |
+ message.erase(message.begin(), message.begin() + 1 + session_id_length); |
+ |
+ // Get the cipher suite list. |
+ ASSERT_GE(message.size(), 2u); |
+ uint16_t cipher_suites_length = (message[0] << 8) | message[1]; |
+ EXPECT_EQ(0, cipher_suites_length % 2); |
+ ASSERT_GE(message.size(), 2u + cipher_suites_length); |
+ std::vector<uint8_t> cipher_suites( |
+ message.begin() + 2, message.begin() + 2 + cipher_suites_length); |
+ |
+ // Finally, ensure there are no ECDSA cipher suites in there. |
+ for (size_t i = 0; i+1 < cipher_suites.size(); i+=2) { |
+ uint16_t cipher_suite = (cipher_suites[i] << 8) | cipher_suites[i+1]; |
+ EXPECT_FALSE(IsECDSACipherSuite(cipher_suite)) |
+ << "ClientHello advertised " << std::hex << cipher_suite; |
+ } |
+} |
+ |
+#endif // OS_WIN |
+ |
#if defined(USE_OPENSSL) |
TEST_F(SSLClientSocketTest, HandshakeCallbackIsRun_WithFailure) { |