OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/socket/ssl_client_socket.h" | 5 #include "net/socket/ssl_client_socket.h" |
6 | 6 |
7 #include "base/callback_helpers.h" | 7 #include "base/callback_helpers.h" |
8 #include "base/memory/ref_counted.h" | 8 #include "base/memory/ref_counted.h" |
9 #include "base/run_loop.h" | 9 #include "base/run_loop.h" |
10 #include "base/time/time.h" | 10 #include "base/time/time.h" |
(...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
53 // This is to provide a common base class for subclasses to override specific | 53 // This is to provide a common base class for subclasses to override specific |
54 // StreamSocket methods for testing, while still communicating with a 'real' | 54 // StreamSocket methods for testing, while still communicating with a 'real' |
55 // StreamSocket. | 55 // StreamSocket. |
56 class WrappedStreamSocket : public StreamSocket { | 56 class WrappedStreamSocket : public StreamSocket { |
57 public: | 57 public: |
58 explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) | 58 explicit WrappedStreamSocket(scoped_ptr<StreamSocket> transport) |
59 : transport_(transport.Pass()) {} | 59 : transport_(transport.Pass()) {} |
60 virtual ~WrappedStreamSocket() {} | 60 virtual ~WrappedStreamSocket() {} |
61 | 61 |
62 // StreamSocket implementation: | 62 // StreamSocket implementation: |
63 virtual int Connect(const CompletionCallback& callback) OVERRIDE { | 63 virtual int Connect(const CompletionCallback& callback) override { |
64 return transport_->Connect(callback); | 64 return transport_->Connect(callback); |
65 } | 65 } |
66 virtual void Disconnect() OVERRIDE { transport_->Disconnect(); } | 66 virtual void Disconnect() override { transport_->Disconnect(); } |
67 virtual bool IsConnected() const OVERRIDE { | 67 virtual bool IsConnected() const override { |
68 return transport_->IsConnected(); | 68 return transport_->IsConnected(); |
69 } | 69 } |
70 virtual bool IsConnectedAndIdle() const OVERRIDE { | 70 virtual bool IsConnectedAndIdle() const override { |
71 return transport_->IsConnectedAndIdle(); | 71 return transport_->IsConnectedAndIdle(); |
72 } | 72 } |
73 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { | 73 virtual int GetPeerAddress(IPEndPoint* address) const override { |
74 return transport_->GetPeerAddress(address); | 74 return transport_->GetPeerAddress(address); |
75 } | 75 } |
76 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { | 76 virtual int GetLocalAddress(IPEndPoint* address) const override { |
77 return transport_->GetLocalAddress(address); | 77 return transport_->GetLocalAddress(address); |
78 } | 78 } |
79 virtual const BoundNetLog& NetLog() const OVERRIDE { | 79 virtual const BoundNetLog& NetLog() const override { |
80 return transport_->NetLog(); | 80 return transport_->NetLog(); |
81 } | 81 } |
82 virtual void SetSubresourceSpeculation() OVERRIDE { | 82 virtual void SetSubresourceSpeculation() override { |
83 transport_->SetSubresourceSpeculation(); | 83 transport_->SetSubresourceSpeculation(); |
84 } | 84 } |
85 virtual void SetOmniboxSpeculation() OVERRIDE { | 85 virtual void SetOmniboxSpeculation() override { |
86 transport_->SetOmniboxSpeculation(); | 86 transport_->SetOmniboxSpeculation(); |
87 } | 87 } |
88 virtual bool WasEverUsed() const OVERRIDE { | 88 virtual bool WasEverUsed() const override { |
89 return transport_->WasEverUsed(); | 89 return transport_->WasEverUsed(); |
90 } | 90 } |
91 virtual bool UsingTCPFastOpen() const OVERRIDE { | 91 virtual bool UsingTCPFastOpen() const override { |
92 return transport_->UsingTCPFastOpen(); | 92 return transport_->UsingTCPFastOpen(); |
93 } | 93 } |
94 virtual bool WasNpnNegotiated() const OVERRIDE { | 94 virtual bool WasNpnNegotiated() const override { |
95 return transport_->WasNpnNegotiated(); | 95 return transport_->WasNpnNegotiated(); |
96 } | 96 } |
97 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { | 97 virtual NextProto GetNegotiatedProtocol() const override { |
98 return transport_->GetNegotiatedProtocol(); | 98 return transport_->GetNegotiatedProtocol(); |
99 } | 99 } |
100 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { | 100 virtual bool GetSSLInfo(SSLInfo* ssl_info) override { |
101 return transport_->GetSSLInfo(ssl_info); | 101 return transport_->GetSSLInfo(ssl_info); |
102 } | 102 } |
103 | 103 |
104 // Socket implementation: | 104 // Socket implementation: |
105 virtual int Read(IOBuffer* buf, | 105 virtual int Read(IOBuffer* buf, |
106 int buf_len, | 106 int buf_len, |
107 const CompletionCallback& callback) OVERRIDE { | 107 const CompletionCallback& callback) override { |
108 return transport_->Read(buf, buf_len, callback); | 108 return transport_->Read(buf, buf_len, callback); |
109 } | 109 } |
110 virtual int Write(IOBuffer* buf, | 110 virtual int Write(IOBuffer* buf, |
111 int buf_len, | 111 int buf_len, |
112 const CompletionCallback& callback) OVERRIDE { | 112 const CompletionCallback& callback) override { |
113 return transport_->Write(buf, buf_len, callback); | 113 return transport_->Write(buf, buf_len, callback); |
114 } | 114 } |
115 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { | 115 virtual int SetReceiveBufferSize(int32 size) override { |
116 return transport_->SetReceiveBufferSize(size); | 116 return transport_->SetReceiveBufferSize(size); |
117 } | 117 } |
118 virtual int SetSendBufferSize(int32 size) OVERRIDE { | 118 virtual int SetSendBufferSize(int32 size) override { |
119 return transport_->SetSendBufferSize(size); | 119 return transport_->SetSendBufferSize(size); |
120 } | 120 } |
121 | 121 |
122 protected: | 122 protected: |
123 scoped_ptr<StreamSocket> transport_; | 123 scoped_ptr<StreamSocket> transport_; |
124 }; | 124 }; |
125 | 125 |
126 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that | 126 // ReadBufferingStreamSocket is a wrapper for an existing StreamSocket that |
127 // will ensure a certain amount of data is internally buffered before | 127 // will ensure a certain amount of data is internally buffered before |
128 // satisfying a Read() request. It exists to mimic OS-level internal | 128 // satisfying a Read() request. It exists to mimic OS-level internal |
129 // buffering, but in a way to guarantee that X number of bytes will be | 129 // buffering, but in a way to guarantee that X number of bytes will be |
130 // returned to callers of Read(), regardless of how quickly the OS receives | 130 // returned to callers of Read(), regardless of how quickly the OS receives |
131 // them from the TestServer. | 131 // them from the TestServer. |
132 class ReadBufferingStreamSocket : public WrappedStreamSocket { | 132 class ReadBufferingStreamSocket : public WrappedStreamSocket { |
133 public: | 133 public: |
134 explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); | 134 explicit ReadBufferingStreamSocket(scoped_ptr<StreamSocket> transport); |
135 virtual ~ReadBufferingStreamSocket() {} | 135 virtual ~ReadBufferingStreamSocket() {} |
136 | 136 |
137 // Socket implementation: | 137 // Socket implementation: |
138 virtual int Read(IOBuffer* buf, | 138 virtual int Read(IOBuffer* buf, |
139 int buf_len, | 139 int buf_len, |
140 const CompletionCallback& callback) OVERRIDE; | 140 const CompletionCallback& callback) override; |
141 | 141 |
142 // Sets the internal buffer to |size|. This must not be greater than | 142 // Sets the internal buffer to |size|. This must not be greater than |
143 // the largest value supplied to Read() - that is, it does not handle | 143 // the largest value supplied to Read() - that is, it does not handle |
144 // having "leftovers" at the end of Read(). | 144 // having "leftovers" at the end of Read(). |
145 // Each call to Read() will be prevented from completion until at least | 145 // Each call to Read() will be prevented from completion until at least |
146 // |size| data has been read. | 146 // |size| data has been read. |
147 // Set to 0 to turn off buffering, causing Read() to transparently | 147 // Set to 0 to turn off buffering, causing Read() to transparently |
148 // read via the underlying transport. | 148 // read via the underlying transport. |
149 void SetBufferSize(int size); | 149 void SetBufferSize(int size); |
150 | 150 |
(...skipping 109 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
260 | 260 |
261 // Simulates synchronously receiving an error during Read() or Write() | 261 // Simulates synchronously receiving an error during Read() or Write() |
262 class SynchronousErrorStreamSocket : public WrappedStreamSocket { | 262 class SynchronousErrorStreamSocket : public WrappedStreamSocket { |
263 public: | 263 public: |
264 explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport); | 264 explicit SynchronousErrorStreamSocket(scoped_ptr<StreamSocket> transport); |
265 virtual ~SynchronousErrorStreamSocket() {} | 265 virtual ~SynchronousErrorStreamSocket() {} |
266 | 266 |
267 // Socket implementation: | 267 // Socket implementation: |
268 virtual int Read(IOBuffer* buf, | 268 virtual int Read(IOBuffer* buf, |
269 int buf_len, | 269 int buf_len, |
270 const CompletionCallback& callback) OVERRIDE; | 270 const CompletionCallback& callback) override; |
271 virtual int Write(IOBuffer* buf, | 271 virtual int Write(IOBuffer* buf, |
272 int buf_len, | 272 int buf_len, |
273 const CompletionCallback& callback) OVERRIDE; | 273 const CompletionCallback& callback) override; |
274 | 274 |
275 // Sets the next Read() call and all future calls to return |error|. | 275 // Sets the next Read() call and all future calls to return |error|. |
276 // If there is already a pending asynchronous read, the configured error | 276 // If there is already a pending asynchronous read, the configured error |
277 // will not be returned until that asynchronous read has completed and Read() | 277 // will not be returned until that asynchronous read has completed and Read() |
278 // is called again. | 278 // is called again. |
279 void SetNextReadError(Error error) { | 279 void SetNextReadError(Error error) { |
280 DCHECK_GE(0, error); | 280 DCHECK_GE(0, error); |
281 have_read_error_ = true; | 281 have_read_error_ = true; |
282 pending_read_error_ = error; | 282 pending_read_error_ = error; |
283 } | 283 } |
(...skipping 47 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
331 // deterministic manner (e.g.: independent of the TestServer and the OS's | 331 // deterministic manner (e.g.: independent of the TestServer and the OS's |
332 // semantics). | 332 // semantics). |
333 class FakeBlockingStreamSocket : public WrappedStreamSocket { | 333 class FakeBlockingStreamSocket : public WrappedStreamSocket { |
334 public: | 334 public: |
335 explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport); | 335 explicit FakeBlockingStreamSocket(scoped_ptr<StreamSocket> transport); |
336 virtual ~FakeBlockingStreamSocket() {} | 336 virtual ~FakeBlockingStreamSocket() {} |
337 | 337 |
338 // Socket implementation: | 338 // Socket implementation: |
339 virtual int Read(IOBuffer* buf, | 339 virtual int Read(IOBuffer* buf, |
340 int buf_len, | 340 int buf_len, |
341 const CompletionCallback& callback) OVERRIDE; | 341 const CompletionCallback& callback) override; |
342 virtual int Write(IOBuffer* buf, | 342 virtual int Write(IOBuffer* buf, |
343 int buf_len, | 343 int buf_len, |
344 const CompletionCallback& callback) OVERRIDE; | 344 const CompletionCallback& callback) override; |
345 | 345 |
346 // Blocks read results on the socket. Reads will not complete until | 346 // Blocks read results on the socket. Reads will not complete until |
347 // UnblockReadResult() has been called and a result is ready from the | 347 // UnblockReadResult() has been called and a result is ready from the |
348 // underlying transport. Note: if BlockReadResult() is called while there is a | 348 // underlying transport. Note: if BlockReadResult() is called while there is a |
349 // hanging asynchronous Read(), that Read is blocked. | 349 // hanging asynchronous Read(), that Read is blocked. |
350 void BlockReadResult(); | 350 void BlockReadResult(); |
351 void UnblockReadResult(); | 351 void UnblockReadResult(); |
352 | 352 |
353 // Waits for the blocked Read() call to be complete at the underlying | 353 // Waits for the blocked Read() call to be complete at the underlying |
354 // transport. | 354 // transport. |
(...skipping 192 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
547 public: | 547 public: |
548 explicit CountingStreamSocket(scoped_ptr<StreamSocket> transport) | 548 explicit CountingStreamSocket(scoped_ptr<StreamSocket> transport) |
549 : WrappedStreamSocket(transport.Pass()), | 549 : WrappedStreamSocket(transport.Pass()), |
550 read_count_(0), | 550 read_count_(0), |
551 write_count_(0) {} | 551 write_count_(0) {} |
552 virtual ~CountingStreamSocket() {} | 552 virtual ~CountingStreamSocket() {} |
553 | 553 |
554 // Socket implementation: | 554 // Socket implementation: |
555 virtual int Read(IOBuffer* buf, | 555 virtual int Read(IOBuffer* buf, |
556 int buf_len, | 556 int buf_len, |
557 const CompletionCallback& callback) OVERRIDE { | 557 const CompletionCallback& callback) override { |
558 read_count_++; | 558 read_count_++; |
559 return transport_->Read(buf, buf_len, callback); | 559 return transport_->Read(buf, buf_len, callback); |
560 } | 560 } |
561 virtual int Write(IOBuffer* buf, | 561 virtual int Write(IOBuffer* buf, |
562 int buf_len, | 562 int buf_len, |
563 const CompletionCallback& callback) OVERRIDE { | 563 const CompletionCallback& callback) override { |
564 write_count_++; | 564 write_count_++; |
565 return transport_->Write(buf, buf_len, callback); | 565 return transport_->Write(buf, buf_len, callback); |
566 } | 566 } |
567 | 567 |
568 int read_count() const { return read_count_; } | 568 int read_count() const { return read_count_; } |
569 int write_count() const { return write_count_; } | 569 int write_count() const { return write_count_; } |
570 | 570 |
571 private: | 571 private: |
572 int read_count_; | 572 int read_count_; |
573 int write_count_; | 573 int write_count_; |
(...skipping 28 matching lines...) Expand all Loading... |
602 DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); | 602 DISALLOW_COPY_AND_ASSIGN(DeleteSocketCallback); |
603 }; | 603 }; |
604 | 604 |
605 // A ChannelIDStore that always returns an error when asked for a | 605 // A ChannelIDStore that always returns an error when asked for a |
606 // channel id. | 606 // channel id. |
607 class FailingChannelIDStore : public ChannelIDStore { | 607 class FailingChannelIDStore : public ChannelIDStore { |
608 virtual int GetChannelID(const std::string& server_identifier, | 608 virtual int GetChannelID(const std::string& server_identifier, |
609 base::Time* expiration_time, | 609 base::Time* expiration_time, |
610 std::string* private_key_result, | 610 std::string* private_key_result, |
611 std::string* cert_result, | 611 std::string* cert_result, |
612 const GetChannelIDCallback& callback) OVERRIDE { | 612 const GetChannelIDCallback& callback) override { |
613 return ERR_UNEXPECTED; | 613 return ERR_UNEXPECTED; |
614 } | 614 } |
615 virtual void SetChannelID(const std::string& server_identifier, | 615 virtual void SetChannelID(const std::string& server_identifier, |
616 base::Time creation_time, | 616 base::Time creation_time, |
617 base::Time expiration_time, | 617 base::Time expiration_time, |
618 const std::string& private_key, | 618 const std::string& private_key, |
619 const std::string& cert) OVERRIDE {} | 619 const std::string& cert) override {} |
620 virtual void DeleteChannelID(const std::string& server_identifier, | 620 virtual void DeleteChannelID(const std::string& server_identifier, |
621 const base::Closure& completion_callback) | 621 const base::Closure& completion_callback) |
622 OVERRIDE {} | 622 override {} |
623 virtual void DeleteAllCreatedBetween(base::Time delete_begin, | 623 virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
624 base::Time delete_end, | 624 base::Time delete_end, |
625 const base::Closure& completion_callback) | 625 const base::Closure& completion_callback) |
626 OVERRIDE {} | 626 override {} |
627 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} | 627 virtual void DeleteAll(const base::Closure& completion_callback) override {} |
628 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) | 628 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) |
629 OVERRIDE {} | 629 override {} |
630 virtual int GetChannelIDCount() OVERRIDE { return 0; } | 630 virtual int GetChannelIDCount() override { return 0; } |
631 virtual void SetForceKeepSessionState() OVERRIDE {} | 631 virtual void SetForceKeepSessionState() override {} |
632 }; | 632 }; |
633 | 633 |
634 // A ChannelIDStore that asynchronously returns an error when asked for a | 634 // A ChannelIDStore that asynchronously returns an error when asked for a |
635 // channel id. | 635 // channel id. |
636 class AsyncFailingChannelIDStore : public ChannelIDStore { | 636 class AsyncFailingChannelIDStore : public ChannelIDStore { |
637 virtual int GetChannelID(const std::string& server_identifier, | 637 virtual int GetChannelID(const std::string& server_identifier, |
638 base::Time* expiration_time, | 638 base::Time* expiration_time, |
639 std::string* private_key_result, | 639 std::string* private_key_result, |
640 std::string* cert_result, | 640 std::string* cert_result, |
641 const GetChannelIDCallback& callback) OVERRIDE { | 641 const GetChannelIDCallback& callback) override { |
642 base::MessageLoop::current()->PostTask( | 642 base::MessageLoop::current()->PostTask( |
643 FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, | 643 FROM_HERE, base::Bind(callback, ERR_UNEXPECTED, |
644 server_identifier, base::Time(), "", "")); | 644 server_identifier, base::Time(), "", "")); |
645 return ERR_IO_PENDING; | 645 return ERR_IO_PENDING; |
646 } | 646 } |
647 virtual void SetChannelID(const std::string& server_identifier, | 647 virtual void SetChannelID(const std::string& server_identifier, |
648 base::Time creation_time, | 648 base::Time creation_time, |
649 base::Time expiration_time, | 649 base::Time expiration_time, |
650 const std::string& private_key, | 650 const std::string& private_key, |
651 const std::string& cert) OVERRIDE {} | 651 const std::string& cert) override {} |
652 virtual void DeleteChannelID(const std::string& server_identifier, | 652 virtual void DeleteChannelID(const std::string& server_identifier, |
653 const base::Closure& completion_callback) | 653 const base::Closure& completion_callback) |
654 OVERRIDE {} | 654 override {} |
655 virtual void DeleteAllCreatedBetween(base::Time delete_begin, | 655 virtual void DeleteAllCreatedBetween(base::Time delete_begin, |
656 base::Time delete_end, | 656 base::Time delete_end, |
657 const base::Closure& completion_callback) | 657 const base::Closure& completion_callback) |
658 OVERRIDE {} | 658 override {} |
659 virtual void DeleteAll(const base::Closure& completion_callback) OVERRIDE {} | 659 virtual void DeleteAll(const base::Closure& completion_callback) override {} |
660 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) | 660 virtual void GetAllChannelIDs(const GetChannelIDListCallback& callback) |
661 OVERRIDE {} | 661 override {} |
662 virtual int GetChannelIDCount() OVERRIDE { return 0; } | 662 virtual int GetChannelIDCount() override { return 0; } |
663 virtual void SetForceKeepSessionState() OVERRIDE {} | 663 virtual void SetForceKeepSessionState() override {} |
664 }; | 664 }; |
665 | 665 |
666 // A mock CTVerifier that records every call to Verify but doesn't verify | 666 // A mock CTVerifier that records every call to Verify but doesn't verify |
667 // anything. | 667 // anything. |
668 class MockCTVerifier : public CTVerifier { | 668 class MockCTVerifier : public CTVerifier { |
669 public: | 669 public: |
670 MOCK_METHOD5(Verify, int(X509Certificate*, | 670 MOCK_METHOD5(Verify, int(X509Certificate*, |
671 const std::string&, | 671 const std::string&, |
672 const std::string&, | 672 const std::string&, |
673 ct::CTVerifyResult*, | 673 ct::CTVerifyResult*, |
(...skipping 2309 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
2983 ssl_config.channel_id_enabled = true; | 2983 ssl_config.channel_id_enabled = true; |
2984 | 2984 |
2985 int rv; | 2985 int rv; |
2986 ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); | 2986 ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); |
2987 | 2987 |
2988 EXPECT_EQ(ERR_UNEXPECTED, rv); | 2988 EXPECT_EQ(ERR_UNEXPECTED, rv); |
2989 EXPECT_FALSE(sock_->IsConnected()); | 2989 EXPECT_FALSE(sock_->IsConnected()); |
2990 } | 2990 } |
2991 | 2991 |
2992 } // namespace net | 2992 } // namespace net |
OLD | NEW |