OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/socket/ssl_client_socket.h" | 5 #include "net/socket/ssl_client_socket.h" |
6 | 6 |
7 #include "base/callback_helpers.h" | 7 #include "base/callback_helpers.h" |
8 #include "base/memory/ref_counted.h" | 8 #include "base/memory/ref_counted.h" |
9 #include "base/run_loop.h" | 9 #include "base/run_loop.h" |
10 #include "base/time/time.h" | 10 #include "base/time/time.h" |
11 #include "net/base/address_list.h" | 11 #include "net/base/address_list.h" |
12 #include "net/base/io_buffer.h" | 12 #include "net/base/io_buffer.h" |
13 #include "net/base/net_errors.h" | 13 #include "net/base/net_errors.h" |
14 #include "net/base/net_log.h" | 14 #include "net/base/net_log.h" |
15 #include "net/base/net_log_unittest.h" | 15 #include "net/base/net_log_unittest.h" |
16 #include "net/base/test_completion_callback.h" | 16 #include "net/base/test_completion_callback.h" |
17 #include "net/base/test_data_directory.h" | 17 #include "net/base/test_data_directory.h" |
18 #include "net/cert/mock_cert_verifier.h" | 18 #include "net/cert/mock_cert_verifier.h" |
19 #include "net/cert/test_root_certs.h" | 19 #include "net/cert/test_root_certs.h" |
20 #include "net/dns/host_resolver.h" | 20 #include "net/dns/host_resolver.h" |
21 #include "net/http/transport_security_state.h" | 21 #include "net/http/transport_security_state.h" |
22 #include "net/socket/client_socket_factory.h" | 22 #include "net/socket/client_socket_factory.h" |
23 #include "net/socket/client_socket_handle.h" | 23 #include "net/socket/client_socket_handle.h" |
24 #include "net/socket/socket_test_util.h" | 24 #include "net/socket/socket_test_util.h" |
25 #include "net/socket/tcp_client_socket.h" | 25 #include "net/socket/tcp_client_socket.h" |
26 #include "net/ssl/default_server_bound_cert_store.h" | 26 #include "net/ssl/channel_id_service.h" |
| 27 #include "net/ssl/default_channel_id_store.h" |
27 #include "net/ssl/ssl_cert_request_info.h" | 28 #include "net/ssl/ssl_cert_request_info.h" |
28 #include "net/ssl/ssl_config_service.h" | 29 #include "net/ssl/ssl_config_service.h" |
29 #include "net/test/cert_test_util.h" | 30 #include "net/test/cert_test_util.h" |
30 #include "net/test/spawned_test_server/spawned_test_server.h" | 31 #include "net/test/spawned_test_server/spawned_test_server.h" |
31 #include "testing/gtest/include/gtest/gtest.h" | 32 #include "testing/gtest/include/gtest/gtest.h" |
32 #include "testing/platform_test.h" | 33 #include "testing/platform_test.h" |
33 | 34 |
34 //----------------------------------------------------------------------------- | 35 //----------------------------------------------------------------------------- |
35 | 36 |
36 namespace net { | 37 namespace net { |
(...skipping 547 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
584 } | 585 } |
585 SetResult(result); | 586 SetResult(result); |
586 } | 587 } |
587 | 588 |
588 StreamSocket* socket_; | 589 StreamSocket* socket_; |
589 CompletionCallback callback_; | 590 CompletionCallback callback_; |
590 | 591 |
591 DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); | 592 DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); |
592 }; | 593 }; |
593 | 594 |
594 // A ServerBoundCertStore that always returns an error when asked for a | 595 // A ChannelIDStore that always returns an error when asked for a |
595 // certificate. | 596 // channel id. |
596 class FailingServerBoundCertStore : public ServerBoundCertStore { | 597 class FailingChannelIDStore : public ChannelIDStore { |
597 virtual int GetServerBoundCert(const std::string& server_identifier, | 598 virtual int GetChannelID(const std::string& server_identifier, |
598 base::Time* expiration_time, | 599 base::Time* expiration_time, |
599 std::string* private_key_result, | 600 std::string* private_key_result, |
600 std::string* cert_result, | 601 std::string* cert_result, |
601 const GetCertCallback& callback) OVERRIDE { | 602 const GetChannelIDCallback& callback) OVERRIDE { |
602 return ERR_UNEXPECTED; | 603 return ERR_UNEXPECTED; |
603 } | 604 } |
604 virtual void SetServerBoundCert(const std::string& server_identifier, | 605 virtual void SetChannelID(const std::string& server_identifier, |
605 base::Time creation_time, | 606 base::Time creation_time, |
606 base::Time expiration_time, | 607 base::Time expiration_time, |
607 const std::string& private_key, | 608 const std::string& private_key, |
608 const std::string& cert) OVERRIDE {} | 609 const std::string& cert) OVERRIDE {} |
609 virtual void DeleteServerBoundCert(const std::string& server_identifier, | 610 virtual void DeleteChannelID(const std::string& server_identifier, |
610 const base::Closure& completion_callback) | 611 const base::Closure& completion_callback) |
611 OVERRIDE {} | 612 OVERRIDE {} |
612 virtual void DeleteAllCreatedBetween(base::Time delete_begin, | 613 virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
613 base::Time delete_end, | 614 base::Time delete_end, |
614 const base::Closure& completion_callback) | 615 const base::Closure& completion_callback) |
615 OVERRIDE {} | 616 OVERRIDE {} |
616 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} | 617 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} |
617 virtual void GetAllServerBoundCerts(const GetCertListCallback& callback) | 618 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) |
618 OVERRIDE {} | 619 OVERRIDE {} |
619 virtual int GetCertCount() OVERRIDE { return 0; } | 620 virtual int GetChannelIDCount() OVERRIDE { return 0; } |
620 virtual void SetForceKeepSessionState() OVERRIDE {} | 621 virtual void SetForceKeepSessionState() OVERRIDE {} |
621 }; | 622 }; |
622 | 623 |
623 // A ServerBoundCertStore that asynchronously returns an error when asked for a | 624 // A ChannelIDStore that asynchronously returns an error when asked for a |
624 // certificate. | 625 // channel id. |
625 class AsyncFailingServerBoundCertStore : public ServerBoundCertStore { | 626 class AsyncFailingChannelIDStore : public ChannelIDStore { |
626 virtual int GetServerBoundCert(const std::string& server_identifier, | 627 virtual int GetChannelID(const std::string& server_identifier, |
627 base::Time* expiration_time, | 628 base::Time* expiration_time, |
628 std::string* private_key_result, | 629 std::string* private_key_result, |
629 std::string* cert_result, | 630 std::string* cert_result, |
630 const GetCertCallback& callback) OVERRIDE { | 631 const GetChannelIDCallback& callback) OVERRIDE { |
631 base::MessageLoop::current()->PostTask( | 632 base::MessageLoop::current()->PostTask( |
632 FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, | 633 FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, |
633 server_identifier, base::Time(), "", "")); | 634 server_identifier, base::Time(), "", "")); |
634 return ERR_IO_PENDING; | 635 return ERR_IO_PENDING; |
635 } | 636 } |
636 virtual void SetServerBoundCert(const std::string& server_identifier, | 637 virtual void SetChannelID(const std::string& server_identifier, |
637 base::Time creation_time, | 638 base::Time creation_time, |
638 base::Time expiration_time, | 639 base::Time expiration_time, |
639 const std::string& private_key, | 640 const std::string& private_key, |
640 const std::string& cert) OVERRIDE {} | 641 const std::string& cert) OVERRIDE {} |
641 virtual void DeleteServerBoundCert(const std::string& server_identifier, | 642 virtual void DeleteChannelID(const std::string& server_identifier, |
642 const base::Closure& completion_callback) | 643 const base::Closure& completion_callback) |
643 OVERRIDE {} | 644 OVERRIDE {} |
644 virtual void DeleteAllCreatedBetween(base::Time delete_begin, | 645 virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
645 base::Time delete_end, | 646 base::Time delete_end, |
646 const base::Closure& completion_callback) | 647 const base::Closure& completion_callback) |
647 OVERRIDE {} | 648 OVERRIDE {} |
648 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} | 649 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} |
649 virtual void GetAllServerBoundCerts(const GetCertListCallback& callback) | 650 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) |
650 OVERRIDE {} | 651 OVERRIDE {} |
651 virtual int GetCertCount() OVERRIDE { return 0; } | 652 virtual int GetChannelIDCount() OVERRIDE { return 0; } |
652 virtual void SetForceKeepSessionState() OVERRIDE {} | 653 virtual void SetForceKeepSessionState() OVERRIDE {} |
653 }; | 654 }; |
654 | 655 |
655 class SSLClientSocketTest : public PlatformTest { | 656 class SSLClientSocketTest : public PlatformTest { |
656 public: | 657 public: |
657 SSLClientSocketTest() | 658 SSLClientSocketTest() |
658 : socket_factory_(ClientSocketFactory::GetDefaultFactory()), | 659 : socket_factory_(ClientSocketFactory::GetDefaultFactory()), |
659 cert_verifier_(new MockCertVerifier), | 660 cert_verifier_(new MockCertVerifier), |
660 transport_security_state_(new TransportSecurityState) { | 661 transport_security_state_(new TransportSecurityState) { |
661 cert_verifier_->set_default_result(OK); | 662 cert_verifier_->set_default_result(OK); |
(...skipping 235 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
897 // the server second leg is blocked. | 898 // the server second leg is blocked. |
898 base::RunLoop().RunUntilIdle(); | 899 base::RunLoop().RunUntilIdle(); |
899 EXPECT_FALSE(callback.have_result()); | 900 EXPECT_FALSE(callback.have_result()); |
900 } | 901 } |
901 } | 902 } |
902 }; | 903 }; |
903 | 904 |
904 class SSLClientSocketChannelIDTest : public SSLClientSocketTest { | 905 class SSLClientSocketChannelIDTest : public SSLClientSocketTest { |
905 protected: | 906 protected: |
906 void EnableChannelID() { | 907 void EnableChannelID() { |
907 cert_service_.reset( | 908 channel_id_service_.reset( |
908 new ServerBoundCertService(new DefaultServerBoundCertStore(NULL), | 909 new ChannelIDService(new DefaultChannelIDStore(NULL), |
909 base::MessageLoopProxy::current())); | 910 base::MessageLoopProxy::current())); |
910 context_.server_bound_cert_service = cert_service_.get(); | 911 context_.channel_id_service = channel_id_service_.get(); |
911 } | 912 } |
912 | 913 |
913 void EnableFailingChannelID() { | 914 void EnableFailingChannelID() { |
914 cert_service_.reset(new ServerBoundCertService( | 915 channel_id_service_.reset(new ChannelIDService( |
915 new FailingServerBoundCertStore(), base::MessageLoopProxy::current())); | 916 new FailingChannelIDStore(), base::MessageLoopProxy::current())); |
916 context_.server_bound_cert_service = cert_service_.get(); | 917 context_.channel_id_service = channel_id_service_.get(); |
917 } | 918 } |
918 | 919 |
919 void EnableAsyncFailingChannelID() { | 920 void EnableAsyncFailingChannelID() { |
920 cert_service_.reset(new ServerBoundCertService( | 921 channel_id_service_.reset(new ChannelIDService( |
921 new AsyncFailingServerBoundCertStore(), | 922 new AsyncFailingChannelIDStore(), |
922 base::MessageLoopProxy::current())); | 923 base::MessageLoopProxy::current())); |
923 context_.server_bound_cert_service = cert_service_.get(); | 924 context_.channel_id_service = channel_id_service_.get(); |
924 } | 925 } |
925 | 926 |
926 private: | 927 private: |
927 scoped_ptr<ServerBoundCertService> cert_service_; | 928 scoped_ptr<ChannelIDService> channel_id_service_; |
928 }; | 929 }; |
929 | 930 |
930 //----------------------------------------------------------------------------- | 931 //----------------------------------------------------------------------------- |
931 | 932 |
932 // LogContainsSSLConnectEndEvent returns true if the given index in the given | 933 // LogContainsSSLConnectEndEvent returns true if the given index in the given |
933 // log is an SSL connect end event. The NSS sockets will cork in an attempt to | 934 // log is an SSL connect end event. The NSS sockets will cork in an attempt to |
934 // merge the first application data record with the Finished message when false | 935 // merge the first application data record with the Finished message when false |
935 // starting. However, in order to avoid the server timing out the handshake, | 936 // starting. However, in order to avoid the server timing out the handshake, |
936 // they'll give up waiting for application data and send the Finished after a | 937 // they'll give up waiting for application data and send the Finished after a |
937 // timeout. This means that an SSL connect end event may appear as a socket | 938 // timeout. This means that an SSL connect end event may appear as a socket |
(...skipping 1869 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
2807 ssl_config.channel_id_enabled = true; | 2808 ssl_config.channel_id_enabled = true; |
2808 | 2809 |
2809 int rv; | 2810 int rv; |
2810 ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); | 2811 ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); |
2811 | 2812 |
2812 EXPECT_EQ(ERR_UNEXPECTED, rv); | 2813 EXPECT_EQ(ERR_UNEXPECTED, rv); |
2813 EXPECT_FALSE(sock_->IsConnected()); | 2814 EXPECT_FALSE(sock_->IsConnected()); |
2814 } | 2815 } |
2815 | 2816 |
2816 } // namespace net | 2817 } // namespace net |
OLD | NEW |