| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/socket/ssl_client_socket.h" | 5 #include "net/socket/ssl_client_socket.h" |
| 6 | 6 |
| 7 #include "base/callback_helpers.h" | 7 #include "base/callback_helpers.h" |
| 8 #include "base/memory/ref_counted.h" | 8 #include "base/memory/ref_counted.h" |
| 9 #include "base/run_loop.h" | 9 #include "base/run_loop.h" |
| 10 #include "base/time/time.h" | 10 #include "base/time/time.h" |
| (...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 53 // This is to provide a common base class for subclasses to override specific | 53 // This is to provide a common base class for subclasses to override specific |
| 54 // StreamSocket methods for testing, while still communicating with a 'real' | 54 // StreamSocket methods for testing, while still communicating with a 'real' |
| 55 // StreamSocket. | 55 // StreamSocket. |
| 56 class WrappedStreamSocket : public StreamSocket { | 56 class WrappedStreamSocket : public StreamSocket { |
| 57 public: | 57 public: |
| 58 explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) | 58 explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) |
| 59 : transport_(transport.Pass()) {} | 59 : transport_(transport.Pass()) {} |
| 60 virtual ~WrappedStreamSocket() {} | 60 virtual ~WrappedStreamSocket() {} |
| 61 | 61 |
| 62 // StreamSocket implementation: | 62 // StreamSocket implementation: |
| 63 virtual int Connect(const CompletionCallback& callback) OVERRIDE { | 63 virtual int Connect(const CompletionCallback& callback) override { |
| 64 return transport_->Connect(callback); | 64 return transport_->Connect(callback); |
| 65 } | 65 } |
| 66 virtual void Disconnect() OVERRIDE { transport_->Disconnect(); } | 66 virtual void Disconnect() override { transport_->Disconnect(); } |
| 67 virtual bool IsConnected() const OVERRIDE { | 67 virtual bool IsConnected() const override { |
| 68 return transport_->IsConnected(); | 68 return transport_->IsConnected(); |
| 69 } | 69 } |
| 70 virtual bool IsConnectedAndIdle() const OVERRIDE { | 70 virtual bool IsConnectedAndIdle() const override { |
| 71 return transport_->IsConnectedAndIdle(); | 71 return transport_->IsConnectedAndIdle(); |
| 72 } | 72 } |
| 73 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { | 73 virtual int GetPeerAddress(IPEndPoint* address) const override { |
| 74 return transport_->GetPeerAddress(address); | 74 return transport_->GetPeerAddress(address); |
| 75 } | 75 } |
| 76 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { | 76 virtual int GetLocalAddress(IPEndPoint* address) const override { |
| 77 return transport_->GetLocalAddress(address); | 77 return transport_->GetLocalAddress(address); |
| 78 } | 78 } |
| 79 virtual const BoundNetLog& NetLog() const OVERRIDE { | 79 virtual const BoundNetLog& NetLog() const override { |
| 80 return transport_->NetLog(); | 80 return transport_->NetLog(); |
| 81 } | 81 } |
| 82 virtual void SetSubresourceSpeculation() OVERRIDE { | 82 virtual void SetSubresourceSpeculation() override { |
| 83 transport_->SetSubresourceSpeculation(); | 83 transport_->SetSubresourceSpeculation(); |
| 84 } | 84 } |
| 85 virtual void SetOmniboxSpeculation() OVERRIDE { | 85 virtual void SetOmniboxSpeculation() override { |
| 86 transport_->SetOmniboxSpeculation(); | 86 transport_->SetOmniboxSpeculation(); |
| 87 } | 87 } |
| 88 virtual bool WasEverUsed() const OVERRIDE { | 88 virtual bool WasEverUsed() const override { |
| 89 return transport_->WasEverUsed(); | 89 return transport_->WasEverUsed(); |
| 90 } | 90 } |
| 91 virtual bool UsingTCPFastOpen() const OVERRIDE { | 91 virtual bool UsingTCPFastOpen() const override { |
| 92 return transport_->UsingTCPFastOpen(); | 92 return transport_->UsingTCPFastOpen(); |
| 93 } | 93 } |
| 94 virtual bool WasNpnNegotiated() const OVERRIDE { | 94 virtual bool WasNpnNegotiated() const override { |
| 95 return transport_->WasNpnNegotiated(); | 95 return transport_->WasNpnNegotiated(); |
| 96 } | 96 } |
| 97 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { | 97 virtual NextProto GetNegotiatedProtocol() const override { |
| 98 return transport_->GetNegotiatedProtocol(); | 98 return transport_->GetNegotiatedProtocol(); |
| 99 } | 99 } |
| 100 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { | 100 virtual bool GetSSLInfo(SSLInfo* ssl_info) override { |
| 101 return transport_->GetSSLInfo(ssl_info); | 101 return transport_->GetSSLInfo(ssl_info); |
| 102 } | 102 } |
| 103 | 103 |
| 104 // Socket implementation: | 104 // Socket implementation: |
| 105 virtual int Read(IOBuffer* buf, | 105 virtual int Read(IOBuffer* buf, |
| 106 int buf_len, | 106 int buf_len, |
| 107 const CompletionCallback& callback) OVERRIDE { | 107 const CompletionCallback& callback) override { |
| 108 return transport_->Read(buf, buf_len, callback); | 108 return transport_->Read(buf, buf_len, callback); |
| 109 } | 109 } |
| 110 virtual int Write(IOBuffer* buf, | 110 virtual int Write(IOBuffer* buf, |
| 111 int buf_len, | 111 int buf_len, |
| 112 const CompletionCallback& callback) OVERRIDE { | 112 const CompletionCallback& callback) override { |
| 113 return transport_->Write(buf, buf_len, callback); | 113 return transport_->Write(buf, buf_len, callback); |
| 114 } | 114 } |
| 115 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { | 115 virtual int SetReceiveBufferSize(int32 size) override { |
| 116 return transport_->SetReceiveBufferSize(size); | 116 return transport_->SetReceiveBufferSize(size); |
| 117 } | 117 } |
| 118 virtual int SetSendBufferSize(int32 size) OVERRIDE { | 118 virtual int SetSendBufferSize(int32 size) override { |
| 119 return transport_->SetSendBufferSize(size); | 119 return transport_->SetSendBufferSize(size); |
| 120 } | 120 } |
| 121 | 121 |
| 122 protected: | 122 protected: |
| 123 scoped_ptr<StreamSocket> transport_; | 123 scoped_ptr<StreamSocket> transport_; |
| 124 }; | 124 }; |
| 125 | 125 |
| 126 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that | 126 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that |
| 127 // will ensure a certain amount of data is internally buffered before | 127 // will ensure a certain amount of data is internally buffered before |
| 128 // satisfying a Read() request. It exists to mimic OS-level internal | 128 // satisfying a Read() request. It exists to mimic OS-level internal |
| 129 // buffering, but in a way to guarantee that X number of bytes will be | 129 // buffering, but in a way to guarantee that X number of bytes will be |
| 130 // returned to callers of Read(), regardless of how quickly the OS receives | 130 // returned to callers of Read(), regardless of how quickly the OS receives |
| 131 // them from the TestServer. | 131 // them from the TestServer. |
| 132 class ReadBufferingStreamSocket : public WrappedStreamSocket { | 132 class ReadBufferingStreamSocket : public WrappedStreamSocket { |
| 133 public: | 133 public: |
| 134 explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); | 134 explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); |
| 135 virtual ~ReadBufferingStreamSocket() {} | 135 virtual ~ReadBufferingStreamSocket() {} |
| 136 | 136 |
| 137 // Socket implementation: | 137 // Socket implementation: |
| 138 virtual int Read(IOBuffer* buf, | 138 virtual int Read(IOBuffer* buf, |
| 139 int buf_len, | 139 int buf_len, |
| 140 const CompletionCallback& callback) OVERRIDE; | 140 const CompletionCallback& callback) override; |
| 141 | 141 |
| 142 // Sets the internal buffer to |size|. This must not be greater than | 142 // Sets the internal buffer to |size|. This must not be greater than |
| 143 // the largest value supplied to Read() - that is, it does not handle | 143 // the largest value supplied to Read() - that is, it does not handle |
| 144 // having "leftovers" at the end of Read(). | 144 // having "leftovers" at the end of Read(). |
| 145 // Each call to Read() will be prevented from completion until at least | 145 // Each call to Read() will be prevented from completion until at least |
| 146 // |size| data has been read. | 146 // |size| data has been read. |
| 147 // Set to 0 to turn off buffering, causing Read() to transparently | 147 // Set to 0 to turn off buffering, causing Read() to transparently |
| 148 // read via the underlying transport. | 148 // read via the underlying transport. |
| 149 void SetBufferSize(int size); | 149 void SetBufferSize(int size); |
| 150 | 150 |
| (...skipping 109 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 260 | 260 |
| 261 // Simulates synchronously receiving an error during Read() or Write() | 261 // Simulates synchronously receiving an error during Read() or Write() |
| 262 class SynchronousErrorStreamSocket : public WrappedStreamSocket { | 262 class SynchronousErrorStreamSocket : public WrappedStreamSocket { |
| 263 public: | 263 public: |
| 264 explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport); | 264 explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport); |
| 265 virtual ~SynchronousErrorStreamSocket() {} | 265 virtual ~SynchronousErrorStreamSocket() {} |
| 266 | 266 |
| 267 // Socket implementation: | 267 // Socket implementation: |
| 268 virtual int Read(IOBuffer* buf, | 268 virtual int Read(IOBuffer* buf, |
| 269 int buf_len, | 269 int buf_len, |
| 270 const CompletionCallback& callback) OVERRIDE; | 270 const CompletionCallback& callback) override; |
| 271 virtual int Write(IOBuffer* buf, | 271 virtual int Write(IOBuffer* buf, |
| 272 int buf_len, | 272 int buf_len, |
| 273 const CompletionCallback& callback) OVERRIDE; | 273 const CompletionCallback& callback) override; |
| 274 | 274 |
| 275 // Sets the next Read() call and all future calls to return |error|. | 275 // Sets the next Read() call and all future calls to return |error|. |
| 276 // If there is already a pending asynchronous read, the configured error | 276 // If there is already a pending asynchronous read, the configured error |
| 277 // will not be returned until that asynchronous read has completed and Read() | 277 // will not be returned until that asynchronous read has completed and Read() |
| 278 // is called again. | 278 // is called again. |
| 279 void SetNextReadError(Error error) { | 279 void SetNextReadError(Error error) { |
| 280 DCHECK_GE(0, error); | 280 DCHECK_GE(0, error); |
| 281 have_read_error_ = true; | 281 have_read_error_ = true; |
| 282 pending_read_error_ = error; | 282 pending_read_error_ = error; |
| 283 } | 283 } |
| (...skipping 47 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 331 // deterministic manner (e.g.: independent of the TestServer and the OS's | 331 // deterministic manner (e.g.: independent of the TestServer and the OS's |
| 332 // semantics). | 332 // semantics). |
| 333 class FakeBlockingStreamSocket : public WrappedStreamSocket { | 333 class FakeBlockingStreamSocket : public WrappedStreamSocket { |
| 334 public: | 334 public: |
| 335 explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport); | 335 explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport); |
| 336 virtual ~FakeBlockingStreamSocket() {} | 336 virtual ~FakeBlockingStreamSocket() {} |
| 337 | 337 |
| 338 // Socket implementation: | 338 // Socket implementation: |
| 339 virtual int Read(IOBuffer* buf, | 339 virtual int Read(IOBuffer* buf, |
| 340 int buf_len, | 340 int buf_len, |
| 341 const CompletionCallback& callback) OVERRIDE; | 341 const CompletionCallback& callback) override; |
| 342 virtual int Write(IOBuffer* buf, | 342 virtual int Write(IOBuffer* buf, |
| 343 int buf_len, | 343 int buf_len, |
| 344 const CompletionCallback& callback) OVERRIDE; | 344 const CompletionCallback& callback) override; |
| 345 | 345 |
| 346 // Blocks read results on the socket. Reads will not complete until | 346 // Blocks read results on the socket. Reads will not complete until |
| 347 // UnblockReadResult() has been called and a result is ready from the | 347 // UnblockReadResult() has been called and a result is ready from the |
| 348 // underlying transport. Note: if BlockReadResult() is called while there is a | 348 // underlying transport. Note: if BlockReadResult() is called while there is a |
| 349 // hanging asynchronous Read(), that Read is blocked. | 349 // hanging asynchronous Read(), that Read is blocked. |
| 350 void BlockReadResult(); | 350 void BlockReadResult(); |
| 351 void UnblockReadResult(); | 351 void UnblockReadResult(); |
| 352 | 352 |
| 353 // Waits for the blocked Read() call to be complete at the underlying | 353 // Waits for the blocked Read() call to be complete at the underlying |
| 354 // transport. | 354 // transport. |
| (...skipping 192 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 547 public: | 547 public: |
| 548 explicit CountingStreamSocket(scoped_ptr<StreamSocket> transport) | 548 explicit CountingStreamSocket(scoped_ptr<StreamSocket> transport) |
| 549 : WrappedStreamSocket(transport.Pass()), | 549 : WrappedStreamSocket(transport.Pass()), |
| 550 read_count_(0), | 550 read_count_(0), |
| 551 write_count_(0) {} | 551 write_count_(0) {} |
| 552 virtual ~CountingStreamSocket() {} | 552 virtual ~CountingStreamSocket() {} |
| 553 | 553 |
| 554 // Socket implementation: | 554 // Socket implementation: |
| 555 virtual int Read(IOBuffer* buf, | 555 virtual int Read(IOBuffer* buf, |
| 556 int buf_len, | 556 int buf_len, |
| 557 const CompletionCallback& callback) OVERRIDE { | 557 const CompletionCallback& callback) override { |
| 558 read_count_++; | 558 read_count_++; |
| 559 return transport_->Read(buf, buf_len, callback); | 559 return transport_->Read(buf, buf_len, callback); |
| 560 } | 560 } |
| 561 virtual int Write(IOBuffer* buf, | 561 virtual int Write(IOBuffer* buf, |
| 562 int buf_len, | 562 int buf_len, |
| 563 const CompletionCallback& callback) OVERRIDE { | 563 const CompletionCallback& callback) override { |
| 564 write_count_++; | 564 write_count_++; |
| 565 return transport_->Write(buf, buf_len, callback); | 565 return transport_->Write(buf, buf_len, callback); |
| 566 } | 566 } |
| 567 | 567 |
| 568 int read_count() const { return read_count_; } | 568 int read_count() const { return read_count_; } |
| 569 int write_count() const { return write_count_; } | 569 int write_count() const { return write_count_; } |
| 570 | 570 |
| 571 private: | 571 private: |
| 572 int read_count_; | 572 int read_count_; |
| 573 int write_count_; | 573 int write_count_; |
| (...skipping 28 matching lines...) Expand all Loading... |
| 602 DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); | 602 DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); |
| 603 }; | 603 }; |
| 604 | 604 |
| 605 // A ChannelIDStore that always returns an error when asked for a | 605 // A ChannelIDStore that always returns an error when asked for a |
| 606 // channel id. | 606 // channel id. |
| 607 class FailingChannelIDStore : public ChannelIDStore { | 607 class FailingChannelIDStore : public ChannelIDStore { |
| 608 virtual int GetChannelID(const std::string& server_identifier, | 608 virtual int GetChannelID(const std::string& server_identifier, |
| 609 base::Time* expiration_time, | 609 base::Time* expiration_time, |
| 610 std::string* private_key_result, | 610 std::string* private_key_result, |
| 611 std::string* cert_result, | 611 std::string* cert_result, |
| 612 const GetChannelIDCallback& callback) OVERRIDE { | 612 const GetChannelIDCallback& callback) override { |
| 613 return ERR_UNEXPECTED; | 613 return ERR_UNEXPECTED; |
| 614 } | 614 } |
| 615 virtual void SetChannelID(const std::string& server_identifier, | 615 virtual void SetChannelID(const std::string& server_identifier, |
| 616 base::Time creation_time, | 616 base::Time creation_time, |
| 617 base::Time expiration_time, | 617 base::Time expiration_time, |
| 618 const std::string& private_key, | 618 const std::string& private_key, |
| 619 const std::string& cert) OVERRIDE {} | 619 const std::string& cert) override {} |
| 620 virtual void DeleteChannelID(const std::string& server_identifier, | 620 virtual void DeleteChannelID(const std::string& server_identifier, |
| 621 const base::Closure& completion_callback) | 621 const base::Closure& completion_callback) |
| 622 OVERRIDE {} | 622 override {} |
| 623 virtual void DeleteAllCreatedBetween(base::Time delete_begin, | 623 virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
| 624 base::Time delete_end, | 624 base::Time delete_end, |
| 625 const base::Closure& completion_callback) | 625 const base::Closure& completion_callback) |
| 626 OVERRIDE {} | 626 override {} |
| 627 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} | 627 virtual void DeleteAll(const base::Closure& completion_callback) override {} |
| 628 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) | 628 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) |
| 629 OVERRIDE {} | 629 override {} |
| 630 virtual int GetChannelIDCount() OVERRIDE { return 0; } | 630 virtual int GetChannelIDCount() override { return 0; } |
| 631 virtual void SetForceKeepSessionState() OVERRIDE {} | 631 virtual void SetForceKeepSessionState() override {} |
| 632 }; | 632 }; |
| 633 | 633 |
| 634 // A ChannelIDStore that asynchronously returns an error when asked for a | 634 // A ChannelIDStore that asynchronously returns an error when asked for a |
| 635 // channel id. | 635 // channel id. |
| 636 class AsyncFailingChannelIDStore : public ChannelIDStore { | 636 class AsyncFailingChannelIDStore : public ChannelIDStore { |
| 637 virtual int GetChannelID(const std::string& server_identifier, | 637 virtual int GetChannelID(const std::string& server_identifier, |
| 638 base::Time* expiration_time, | 638 base::Time* expiration_time, |
| 639 std::string* private_key_result, | 639 std::string* private_key_result, |
| 640 std::string* cert_result, | 640 std::string* cert_result, |
| 641 const GetChannelIDCallback& callback) OVERRIDE { | 641 const GetChannelIDCallback& callback) override { |
| 642 base::MessageLoop::current()->PostTask( | 642 base::MessageLoop::current()->PostTask( |
| 643 FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, | 643 FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, |
| 644 server_identifier, base::Time(), "", "")); | 644 server_identifier, base::Time(), "", "")); |
| 645 return ERR_IO_PENDING; | 645 return ERR_IO_PENDING; |
| 646 } | 646 } |
| 647 virtual void SetChannelID(const std::string& server_identifier, | 647 virtual void SetChannelID(const std::string& server_identifier, |
| 648 base::Time creation_time, | 648 base::Time creation_time, |
| 649 base::Time expiration_time, | 649 base::Time expiration_time, |
| 650 const std::string& private_key, | 650 const std::string& private_key, |
| 651 const std::string& cert) OVERRIDE {} | 651 const std::string& cert) override {} |
| 652 virtual void DeleteChannelID(const std::string& server_identifier, | 652 virtual void DeleteChannelID(const std::string& server_identifier, |
| 653 const base::Closure& completion_callback) | 653 const base::Closure& completion_callback) |
| 654 OVERRIDE {} | 654 override {} |
| 655 virtual void DeleteAllCreatedBetween(base::Time delete_begin, | 655 virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
| 656 base::Time delete_end, | 656 base::Time delete_end, |
| 657 const base::Closure& completion_callback) | 657 const base::Closure& completion_callback) |
| 658 OVERRIDE {} | 658 override {} |
| 659 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} | 659 virtual void DeleteAll(const base::Closure& completion_callback) override {} |
| 660 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) | 660 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) |
| 661 OVERRIDE {} | 661 override {} |
| 662 virtual int GetChannelIDCount() OVERRIDE { return 0; } | 662 virtual int GetChannelIDCount() override { return 0; } |
| 663 virtual void SetForceKeepSessionState() OVERRIDE {} | 663 virtual void SetForceKeepSessionState() override {} |
| 664 }; | 664 }; |
| 665 | 665 |
| 666 // A mock CTVerifier that records every call to Verify but doesn't verify | 666 // A mock CTVerifier that records every call to Verify but doesn't verify |
| 667 // anything. | 667 // anything. |
| 668 class MockCTVerifier : public CTVerifier { | 668 class MockCTVerifier : public CTVerifier { |
| 669 public: | 669 public: |
| 670 MOCK_METHOD5(Verify, int(X509Certificate*, | 670 MOCK_METHOD5(Verify, int(X509Certificate*, |
| 671 const std::string&, | 671 const std::string&, |
| 672 const std::string&, | 672 const std::string&, |
| 673 ct::CTVerifyResult*, | 673 ct::CTVerifyResult*, |
| (...skipping 2309 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 2983 ssl_config.channel_id_enabled = true; | 2983 ssl_config.channel_id_enabled = true; |
| 2984 | 2984 |
| 2985 int rv; | 2985 int rv; |
| 2986 ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); | 2986 ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); |
| 2987 | 2987 |
| 2988 EXPECT_EQ(ERR_UNEXPECTED, rv); | 2988 EXPECT_EQ(ERR_UNEXPECTED, rv); |
| 2989 EXPECT_FALSE(sock_->IsConnected()); | 2989 EXPECT_FALSE(sock_->IsConnected()); |
| 2990 } | 2990 } |
| 2991 | 2991 |
| 2992 } // namespace net | 2992 } // namespace net |
| OLD | NEW |