| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/dns/dns_transaction.h" | 5 #include "net/dns/dns_transaction.h" |
| 6 | 6 |
| 7 #include "base/bind.h" | 7 #include "base/bind.h" |
| 8 #include "base/memory/scoped_ptr.h" | 8 #include "base/memory/scoped_ptr.h" |
| 9 #include "base/memory/scoped_vector.h" | 9 #include "base/memory/scoped_vector.h" |
| 10 #include "base/rand_util.h" | 10 #include "base/rand_util.h" |
| (...skipping 28 matching lines...) Expand all Loading... |
| 39 uint16 qtype, | 39 uint16 qtype, |
| 40 IoMode mode, | 40 IoMode mode, |
| 41 bool use_tcp) | 41 bool use_tcp) |
| 42 : query_(new DnsQuery(id, DomainFromDot(dotted_name), qtype)), | 42 : query_(new DnsQuery(id, DomainFromDot(dotted_name), qtype)), |
| 43 use_tcp_(use_tcp) { | 43 use_tcp_(use_tcp) { |
| 44 if (use_tcp_) { | 44 if (use_tcp_) { |
| 45 scoped_ptr<uint16> length(new uint16); | 45 scoped_ptr<uint16> length(new uint16); |
| 46 *length = base::HostToNet16(query_->io_buffer()->size()); | 46 *length = base::HostToNet16(query_->io_buffer()->size()); |
| 47 writes_.push_back(MockWrite(mode, | 47 writes_.push_back(MockWrite(mode, |
| 48 reinterpret_cast<const char*>(length.get()), | 48 reinterpret_cast<const char*>(length.get()), |
| 49 sizeof(uint16))); | 49 sizeof(uint16), num_reads_and_writes())); |
| 50 lengths_.push_back(length.release()); | 50 lengths_.push_back(length.release()); |
| 51 } | 51 } |
| 52 writes_.push_back(MockWrite(mode, | 52 writes_.push_back(MockWrite(mode, query_->io_buffer()->data(), |
| 53 query_->io_buffer()->data(), | 53 query_->io_buffer()->size(), |
| 54 query_->io_buffer()->size())); | 54 num_reads_and_writes())); |
| 55 } | 55 } |
| 56 ~DnsSocketData() {} | 56 ~DnsSocketData() {} |
| 57 | 57 |
| 58 // All responses must be added before GetProvider. | 58 // All responses must be added before GetProvider. |
| 59 | 59 |
| 60 // Adds pre-built DnsResponse. |tcp_length| will be used in TCP mode only. | 60 // Adds pre-built DnsResponse. |tcp_length| will be used in TCP mode only. |
| 61 void AddResponseWithLength(scoped_ptr<DnsResponse> response, IoMode mode, | 61 void AddResponseWithLength(scoped_ptr<DnsResponse> response, IoMode mode, |
| 62 uint16 tcp_length) { | 62 uint16 tcp_length) { |
| 63 CHECK(!provider_.get()); | 63 CHECK(!provider_.get()); |
| 64 if (use_tcp_) { | 64 if (use_tcp_) { |
| 65 scoped_ptr<uint16> length(new uint16); | 65 scoped_ptr<uint16> length(new uint16); |
| 66 *length = base::HostToNet16(tcp_length); | 66 *length = base::HostToNet16(tcp_length); |
| 67 reads_.push_back(MockRead(mode, | 67 reads_.push_back(MockRead(mode, |
| 68 reinterpret_cast<const char*>(length.get()), | 68 reinterpret_cast<const char*>(length.get()), |
| 69 sizeof(uint16))); | 69 sizeof(uint16), num_reads_and_writes())); |
| 70 lengths_.push_back(length.release()); | 70 lengths_.push_back(length.release()); |
| 71 } | 71 } |
| 72 reads_.push_back(MockRead(mode, | 72 reads_.push_back(MockRead(mode, response->io_buffer()->data(), |
| 73 response->io_buffer()->data(), | 73 response->io_buffer()->size(), |
| 74 response->io_buffer()->size())); | 74 num_reads_and_writes())); |
| 75 responses_.push_back(response.release()); | 75 responses_.push_back(response.release()); |
| 76 } | 76 } |
| 77 | 77 |
| 78 // Adds pre-built DnsResponse. | 78 // Adds pre-built DnsResponse. |
| 79 void AddResponse(scoped_ptr<DnsResponse> response, IoMode mode) { | 79 void AddResponse(scoped_ptr<DnsResponse> response, IoMode mode) { |
| 80 uint16 tcp_length = response->io_buffer()->size(); | 80 uint16 tcp_length = response->io_buffer()->size(); |
| 81 AddResponseWithLength(response.Pass(), mode, tcp_length); | 81 AddResponseWithLength(response.Pass(), mode, tcp_length); |
| 82 } | 82 } |
| 83 | 83 |
| 84 // Adds pre-built response from |data| buffer. | 84 // Adds pre-built response from |data| buffer. |
| (...skipping 10 matching lines...) Expand all Loading... |
| 95 query_->io_buffer()->size(), | 95 query_->io_buffer()->size(), |
| 96 0)); | 96 0)); |
| 97 dns_protocol::Header* header = | 97 dns_protocol::Header* header = |
| 98 reinterpret_cast<dns_protocol::Header*>(response->io_buffer()->data()); | 98 reinterpret_cast<dns_protocol::Header*>(response->io_buffer()->data()); |
| 99 header->flags |= base::HostToNet16(dns_protocol::kFlagResponse | rcode); | 99 header->flags |= base::HostToNet16(dns_protocol::kFlagResponse | rcode); |
| 100 AddResponse(response.Pass(), mode); | 100 AddResponse(response.Pass(), mode); |
| 101 } | 101 } |
| 102 | 102 |
| 103 // Add error response. | 103 // Add error response. |
| 104 void AddReadError(int error, IoMode mode) { | 104 void AddReadError(int error, IoMode mode) { |
| 105 reads_.push_back(MockRead(mode, error)); | 105 reads_.push_back(MockRead(mode, error, num_reads_and_writes())); |
| 106 } | 106 } |
| 107 | 107 |
| 108 // Build, if needed, and return the SocketDataProvider. No new responses | 108 // Build, if needed, and return the SocketDataProvider. No new responses |
| 109 // should be added afterwards. | 109 // should be added afterwards. |
| 110 SocketDataProvider* GetProvider() { | 110 SequencedSocketData* GetProvider() { |
| 111 if (provider_.get()) | 111 if (provider_.get()) |
| 112 return provider_.get(); | 112 return provider_.get(); |
| 113 // Terminate the reads with ERR_IO_PENDING to prevent overrun and default to | 113 // Terminate the reads with ERR_IO_PENDING to prevent overrun and default to |
| 114 // timeout. | 114 // timeout. |
| 115 reads_.push_back(MockRead(ASYNC, ERR_IO_PENDING)); | 115 reads_.push_back( |
| 116 provider_.reset(new DelayedSocketData(1, &reads_[0], reads_.size(), | 116 MockRead(ASYNC, ERR_IO_PENDING, writes_.size() + reads_.size())); |
| 117 &writes_[0], writes_.size())); | 117 provider_.reset(new SequencedSocketData(&reads_[0], reads_.size(), |
| 118 &writes_[0], writes_.size())); |
| 118 if (use_tcp_) { | 119 if (use_tcp_) { |
| 119 provider_->set_connect_data(MockConnect(reads_[0].mode, OK)); | 120 provider_->set_connect_data(MockConnect(reads_[0].mode, OK)); |
| 120 } | 121 } |
| 121 return provider_.get(); | 122 return provider_.get(); |
| 122 } | 123 } |
| 123 | 124 |
| 124 uint16 query_id() const { | 125 uint16 query_id() const { |
| 125 return query_->id(); | 126 return query_->id(); |
| 126 } | 127 } |
| 127 | 128 |
| 128 // Returns true if the expected query was written to the socket. | 129 private: |
| 129 bool was_written() const { | 130 size_t num_reads_and_writes() const { return reads_.size() + writes_.size(); } |
| 130 CHECK(provider_.get()); | |
| 131 return provider_->write_index() > 0; | |
| 132 } | |
| 133 | 131 |
| 134 private: | |
| 135 scoped_ptr<DnsQuery> query_; | 132 scoped_ptr<DnsQuery> query_; |
| 136 bool use_tcp_; | 133 bool use_tcp_; |
| 137 ScopedVector<uint16> lengths_; | 134 ScopedVector<uint16> lengths_; |
| 138 ScopedVector<DnsResponse> responses_; | 135 ScopedVector<DnsResponse> responses_; |
| 139 std::vector<MockWrite> writes_; | 136 std::vector<MockWrite> writes_; |
| 140 std::vector<MockRead> reads_; | 137 std::vector<MockRead> reads_; |
| 141 scoped_ptr<DelayedSocketData> provider_; | 138 scoped_ptr<SequencedSocketData> provider_; |
| 142 | 139 |
| 143 DISALLOW_COPY_AND_ASSIGN(DnsSocketData); | 140 DISALLOW_COPY_AND_ASSIGN(DnsSocketData); |
| 144 }; | 141 }; |
| 145 | 142 |
| 146 class TestSocketFactory; | 143 class TestSocketFactory; |
| 147 | 144 |
| 148 // A variant of MockUDPClientSocket which always fails to Connect. | 145 // A variant of MockUDPClientSocket which always fails to Connect. |
| 149 class FailingUDPClientSocket : public MockUDPClientSocket { | 146 class FailingUDPClientSocket : public MockUDPClientSocket { |
| 150 public: | 147 public: |
| 151 FailingUDPClientSocket(SocketDataProvider* data, | 148 FailingUDPClientSocket(SocketDataProvider* data, |
| (...skipping 295 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 447 // and no retransmissions, | 444 // and no retransmissions, |
| 448 config_.attempts = 1; | 445 config_.attempts = 1; |
| 449 // but long enough timeout for memory tests. | 446 // but long enough timeout for memory tests. |
| 450 config_.timeout = TestTimeouts::action_timeout(); | 447 config_.timeout = TestTimeouts::action_timeout(); |
| 451 ConfigureFactory(); | 448 ConfigureFactory(); |
| 452 } | 449 } |
| 453 | 450 |
| 454 void TearDown() override { | 451 void TearDown() override { |
| 455 // Check that all socket data was at least written to. | 452 // Check that all socket data was at least written to. |
| 456 for (size_t i = 0; i < socket_data_.size(); ++i) { | 453 for (size_t i = 0; i < socket_data_.size(); ++i) { |
| 457 EXPECT_TRUE(socket_data_[i]->was_written()) << i; | 454 EXPECT_TRUE(socket_data_[i]->GetProvider()->AllWriteDataConsumed()) << i; |
| 458 } | 455 } |
| 459 } | 456 } |
| 460 | 457 |
| 461 protected: | 458 protected: |
| 462 int GetNextId(int min, int max) { | 459 int GetNextId(int min, int max) { |
| 463 EXPECT_FALSE(transaction_ids_.empty()); | 460 EXPECT_FALSE(transaction_ids_.empty()); |
| 464 int id = transaction_ids_.front(); | 461 int id = transaction_ids_.front(); |
| 465 transaction_ids_.pop_front(); | 462 transaction_ids_.pop_front(); |
| 466 EXPECT_GE(id, min); | 463 EXPECT_GE(id, min); |
| 467 EXPECT_LE(id, max); | 464 EXPECT_LE(id, max); |
| (...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 510 kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); | 507 kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); |
| 511 AddAsyncQueryAndResponse(1 /* id */, kT1HostName, kT1Qtype, | 508 AddAsyncQueryAndResponse(1 /* id */, kT1HostName, kT1Qtype, |
| 512 kT1ResponseDatagram, arraysize(kT1ResponseDatagram)); | 509 kT1ResponseDatagram, arraysize(kT1ResponseDatagram)); |
| 513 | 510 |
| 514 TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); | 511 TransactionHelper helper0(kT0HostName, kT0Qtype, kT0RecordCount); |
| 515 helper0.StartTransaction(transaction_factory_.get()); | 512 helper0.StartTransaction(transaction_factory_.get()); |
| 516 TransactionHelper helper1(kT1HostName, kT1Qtype, kT1RecordCount); | 513 TransactionHelper helper1(kT1HostName, kT1Qtype, kT1RecordCount); |
| 517 helper1.StartTransaction(transaction_factory_.get()); | 514 helper1.StartTransaction(transaction_factory_.get()); |
| 518 | 515 |
| 519 helper0.Cancel(); | 516 helper0.Cancel(); |
| 517 // Since the transaction has been cancelled, the assocaited socket has been |
| 518 // destroyed, so make sure the data provide does not attempt to callback |
| 519 // to the socket. |
| 520 // TODO(rch): Make the SocketDataProvider and MockSocket do this by default. |
| 521 socket_data_[0]->GetProvider()->set_socket(nullptr); |
| 520 | 522 |
| 521 base::MessageLoop::current()->RunUntilIdle(); | 523 base::MessageLoop::current()->RunUntilIdle(); |
| 522 | 524 |
| 523 EXPECT_FALSE(helper0.has_completed()); | 525 EXPECT_FALSE(helper0.has_completed()); |
| 524 EXPECT_TRUE(helper1.has_completed()); | 526 EXPECT_TRUE(helper1.has_completed()); |
| 525 } | 527 } |
| 526 | 528 |
| 527 TEST_F(DnsTransactionTest, DestroyFactory) { | 529 TEST_F(DnsTransactionTest, DestroyFactory) { |
| 528 AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, | 530 AddAsyncQueryAndResponse(0 /* id */, kT0HostName, kT0Qtype, |
| 529 kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); | 531 kT0ResponseDatagram, arraysize(kT0ResponseDatagram)); |
| (...skipping 472 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 1002 config_.timeout = TestTimeouts::tiny_timeout(); | 1004 config_.timeout = TestTimeouts::tiny_timeout(); |
| 1003 ConfigureFactory(); | 1005 ConfigureFactory(); |
| 1004 | 1006 |
| 1005 TransactionHelper helper0(".", dns_protocol::kTypeA, ERR_INVALID_ARGUMENT); | 1007 TransactionHelper helper0(".", dns_protocol::kTypeA, ERR_INVALID_ARGUMENT); |
| 1006 EXPECT_TRUE(helper0.Run(transaction_factory_.get())); | 1008 EXPECT_TRUE(helper0.Run(transaction_factory_.get())); |
| 1007 } | 1009 } |
| 1008 | 1010 |
| 1009 } // namespace | 1011 } // namespace |
| 1010 | 1012 |
| 1011 } // namespace net | 1013 } // namespace net |
| OLD | NEW |