| OLD | NEW |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "net/dns/dns_test_util.h" | 5 #include "net/dns/dns_test_util.h" |
| 6 | 6 |
| 7 #include <string> | 7 #include <string> |
| 8 | 8 |
| 9 #include "base/big_endian.h" | 9 #include "base/big_endian.h" |
| 10 #include "base/bind.h" | 10 #include "base/bind.h" |
| 11 #include "base/callback.h" |
| 11 #include "base/location.h" | 12 #include "base/location.h" |
| 12 #include "base/memory/weak_ptr.h" | 13 #include "base/memory/weak_ptr.h" |
| 13 #include "base/single_thread_task_runner.h" | 14 #include "base/single_thread_task_runner.h" |
| 14 #include "base/sys_byteorder.h" | 15 #include "base/sys_byteorder.h" |
| 15 #include "base/threading/thread_task_runner_handle.h" | 16 #include "base/threading/thread_task_runner_handle.h" |
| 16 #include "net/base/io_buffer.h" | 17 #include "net/base/io_buffer.h" |
| 17 #include "net/base/net_errors.h" | 18 #include "net/base/net_errors.h" |
| 18 #include "net/dns/address_sorter.h" | 19 #include "net/dns/address_sorter.h" |
| 19 #include "net/dns/dns_query.h" | 20 #include "net/dns/dns_query.h" |
| 20 #include "net/dns/dns_response.h" | 21 #include "net/dns/dns_response.h" |
| (...skipping 15 matching lines...) Expand all Loading... |
| 36 }; | 37 }; |
| 37 | 38 |
| 38 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. | 39 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. |
| 39 class MockTransaction : public DnsTransaction, | 40 class MockTransaction : public DnsTransaction, |
| 40 public base::SupportsWeakPtr<MockTransaction> { | 41 public base::SupportsWeakPtr<MockTransaction> { |
| 41 public: | 42 public: |
| 42 MockTransaction(const MockDnsClientRuleList& rules, | 43 MockTransaction(const MockDnsClientRuleList& rules, |
| 43 const std::string& hostname, | 44 const std::string& hostname, |
| 44 uint16_t qtype, | 45 uint16_t qtype, |
| 45 const DnsTransactionFactory::CallbackType& callback) | 46 const DnsTransactionFactory::CallbackType& callback) |
| 46 : result_(MockDnsClientRule::FAIL), | 47 : hostname_(hostname), |
| 47 hostname_(hostname), | |
| 48 qtype_(qtype), | 48 qtype_(qtype), |
| 49 callback_(callback), | 49 callback_(callback), |
| 50 started_(false), | 50 started_(false), |
| 51 delayed_(false) { | 51 delayed_(false) { |
| 52 // Find the relevant rule which matches |qtype| and prefix of |hostname|. | 52 // Find the relevant rule which matches |qtype| and prefix of |hostname|. |
| 53 for (size_t i = 0; i < rules.size(); ++i) { | 53 for (size_t i = 0; i < rules.size(); ++i) { |
| 54 const std::string& prefix = rules[i].prefix; | 54 const std::string& prefix = rules[i].prefix; |
| 55 if ((rules[i].qtype == qtype) && | 55 if ((rules[i].qtype == qtype) && |
| 56 (hostname.size() >= prefix.size()) && | 56 (hostname.size() >= prefix.size()) && |
| 57 (hostname.compare(0, prefix.size(), prefix) == 0)) { | 57 (hostname.compare(0, prefix.size(), prefix) == 0)) { |
| 58 result_ = rules[i].result; | 58 response_callback_ = rules[i].response_callback; |
| 59 delayed_ = rules[i].delay; | 59 delayed_ = rules[i].delay; |
| 60 break; | 60 break; |
| 61 } | 61 } |
| 62 } | 62 } |
| 63 } | 63 } |
| 64 | 64 |
| 65 const std::string& GetHostname() const override { return hostname_; } | 65 const std::string& GetHostname() const override { return hostname_; } |
| 66 | 66 |
| 67 uint16_t GetType() const override { return qtype_; } | 67 uint16_t GetType() const override { return qtype_; } |
| 68 | 68 |
| (...skipping 10 matching lines...) Expand all Loading... |
| 79 void FinishDelayedTransaction() { | 79 void FinishDelayedTransaction() { |
| 80 EXPECT_TRUE(delayed_); | 80 EXPECT_TRUE(delayed_); |
| 81 delayed_ = false; | 81 delayed_ = false; |
| 82 Finish(); | 82 Finish(); |
| 83 } | 83 } |
| 84 | 84 |
| 85 bool delayed() const { return delayed_; } | 85 bool delayed() const { return delayed_; } |
| 86 | 86 |
| 87 private: | 87 private: |
| 88 void Finish() { | 88 void Finish() { |
| 89 switch (result_) { | 89 if (response_callback_.is_null()) { |
| 90 case MockDnsClientRule::EMPTY: | 90 callback_.Run(this, ERR_NAME_NOT_RESOLVED, nullptr); |
| 91 case MockDnsClientRule::OK: { | 91 return; |
| 92 std::string qname; | 92 } |
| 93 DNSDomainFromDot(hostname_, &qname); | |
| 94 DnsQuery query(0, qname, qtype_); | |
| 95 | 93 |
| 96 DnsResponse response; | 94 std::string qname; |
| 97 char* buffer = response.io_buffer()->data(); | 95 DNSDomainFromDot(hostname_, &qname); |
| 98 int nbytes = query.io_buffer()->size(); | 96 DnsQuery query(0, qname, qtype_); |
| 99 memcpy(buffer, query.io_buffer()->data(), nbytes); | |
| 100 dns_protocol::Header* header = | |
| 101 reinterpret_cast<dns_protocol::Header*>(buffer); | |
| 102 header->flags |= dns_protocol::kFlagResponse; | |
| 103 | 97 |
| 104 if (MockDnsClientRule::OK == result_) { | 98 DnsResponse response; |
| 105 const uint16_t kPointerToQueryName = | 99 IOBufferWithSize* buffer = response.io_buffer(); |
| 106 static_cast<uint16_t>(0xc000 | sizeof(*header)); | 100 int query_size = query.io_buffer()->size(); |
| 101 CHECK_GE(buffer->size(), query_size); |
| 102 memcpy(buffer->data(), query.io_buffer()->data(), query_size); |
| 103 dns_protocol::Header* header = |
| 104 reinterpret_cast<dns_protocol::Header*>(buffer->data()); |
| 105 header->flags |= dns_protocol::kFlagResponse; |
| 107 | 106 |
| 108 const uint32_t kTTL = 86400; // One day. | 107 base::BigEndianWriter answer_writer(buffer->data() + query_size, |
| 109 | 108 buffer->size() - query_size); |
| 110 // Size of RDATA which is a IPv4 or IPv6 address. | 109 int net_error = response_callback_.Run(header, &answer_writer); |
| 111 size_t rdata_size = qtype_ == dns_protocol::kTypeA | 110 if (net_error == OK) { |
| 112 ? IPAddress::kIPv4AddressSize | 111 int nbytes = answer_writer.ptr() - buffer->data(); |
| 113 : IPAddress::kIPv6AddressSize; | 112 EXPECT_TRUE(response.InitParse(nbytes, query)); |
| 114 | 113 callback_.Run(this, OK, &response); |
| 115 // 12 is the sum of sizes of the compressed name reference, TYPE, | 114 } else { |
| 116 // CLASS, TTL and RDLENGTH. | 115 callback_.Run(this, net_error, nullptr); |
| 117 size_t answer_size = 12 + rdata_size; | |
| 118 | |
| 119 // Write answer with loopback IP address. | |
| 120 header->ancount = base::HostToNet16(1); | |
| 121 base::BigEndianWriter writer(buffer + nbytes, answer_size); | |
| 122 writer.WriteU16(kPointerToQueryName); | |
| 123 writer.WriteU16(qtype_); | |
| 124 writer.WriteU16(dns_protocol::kClassIN); | |
| 125 writer.WriteU32(kTTL); | |
| 126 writer.WriteU16(static_cast<uint16_t>(rdata_size)); | |
| 127 if (qtype_ == dns_protocol::kTypeA) { | |
| 128 char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; | |
| 129 writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); | |
| 130 } else { | |
| 131 char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, | |
| 132 0, 0, 0, 0, 0, 0, 0, 1 }; | |
| 133 writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); | |
| 134 } | |
| 135 nbytes += answer_size; | |
| 136 } | |
| 137 EXPECT_TRUE(response.InitParse(nbytes, query)); | |
| 138 callback_.Run(this, OK, &response); | |
| 139 } break; | |
| 140 case MockDnsClientRule::FAIL: | |
| 141 callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); | |
| 142 break; | |
| 143 case MockDnsClientRule::TIMEOUT: | |
| 144 callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); | |
| 145 break; | |
| 146 default: | |
| 147 NOTREACHED(); | |
| 148 break; | |
| 149 } | 116 } |
| 150 } | 117 } |
| 151 | 118 |
| 152 MockDnsClientRule::Result result_; | 119 MockDnsClientRule::ResponseCallback response_callback_; |
| 153 const std::string hostname_; | 120 const std::string hostname_; |
| 154 const uint16_t qtype_; | 121 const uint16_t qtype_; |
| 155 DnsTransactionFactory::CallbackType callback_; | 122 DnsTransactionFactory::CallbackType callback_; |
| 156 bool started_; | 123 bool started_; |
| 157 bool delayed_; | 124 bool delayed_; |
| 158 }; | 125 }; |
| 159 | 126 |
| 127 // Simply returns the |net_error| argument. |
| 128 // Useful as a simple callback that does nothing but reports an error. |
| 129 int ReturnNetError(int net_error, |
| 130 dns_protocol::Header* response_header, |
| 131 base::BigEndianWriter* answer_writer) { |
| 132 CHECK_LE(net_error, 0); |
| 133 return net_error; |
| 134 } |
| 135 |
| 136 // Writes a |qtype| record for the loopback address using |answer_writer|. |
| 137 // |qtype| must be |dns_protocol::kTypeA| or |dns_protocol::kTypeAAAA|. |
| 138 // Returns net::OK if successful. |
| 139 int WriteLoopbackRecordResponse(uint16_t qtype, |
| 140 dns_protocol::Header* response_header, |
| 141 base::BigEndianWriter* answer_writer) { |
| 142 const uint16_t kPointerToQueryName = |
| 143 static_cast<uint16_t>(0xc000 | sizeof(*response_header)); |
| 144 |
| 145 const uint32_t kTTL = 86400; // One day. |
| 146 |
| 147 // Write answer with loopback IP address. |
| 148 response_header->ancount = |
| 149 base::HostToNet16(base::NetToHost16(response_header->ancount) + 1); |
| 150 CHECK(answer_writer->WriteU16(kPointerToQueryName)); |
| 151 CHECK(answer_writer->WriteU16(qtype)); |
| 152 CHECK(answer_writer->WriteU16(dns_protocol::kClassIN)); |
| 153 CHECK(answer_writer->WriteU32(kTTL)); |
| 154 if (qtype == dns_protocol::kTypeA) { |
| 155 char kIPv4Loopback[] = {0x7f, 0, 0, 1}; |
| 156 CHECK(answer_writer->WriteU16(IPAddress::kIPv4AddressSize)); |
| 157 CHECK(answer_writer->WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback))); |
| 158 } else if (qtype == dns_protocol::kTypeAAAA) { |
| 159 char kIPv6Loopback[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; |
| 160 CHECK(answer_writer->WriteU16(IPAddress::kIPv6AddressSize)); |
| 161 CHECK(answer_writer->WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback))); |
| 162 } else { |
| 163 NOTREACHED(); |
| 164 } |
| 165 |
| 166 return OK; |
| 167 } |
| 168 |
| 160 } // namespace | 169 } // namespace |
| 161 | 170 |
| 162 // A DnsTransactionFactory which creates MockTransaction. | 171 // A DnsTransactionFactory which creates MockTransaction. |
| 163 class MockTransactionFactory : public DnsTransactionFactory { | 172 class MockTransactionFactory : public DnsTransactionFactory { |
| 164 public: | 173 public: |
| 165 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) | 174 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) |
| 166 : rules_(rules) {} | 175 : rules_(rules) {} |
| 167 | 176 |
| 168 ~MockTransactionFactory() override {} | 177 ~MockTransactionFactory() override {} |
| 169 | 178 |
| (...skipping 19 matching lines...) Expand all Loading... |
| 189 } | 198 } |
| 190 } | 199 } |
| 191 | 200 |
| 192 private: | 201 private: |
| 193 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; | 202 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; |
| 194 | 203 |
| 195 MockDnsClientRuleList rules_; | 204 MockDnsClientRuleList rules_; |
| 196 DelayedTransactionList delayed_transactions_; | 205 DelayedTransactionList delayed_transactions_; |
| 197 }; | 206 }; |
| 198 | 207 |
| 208 MockDnsClientRule::MockDnsClientRule(const std::string& prefix, |
| 209 uint16_t qtype, |
| 210 Result result, |
| 211 bool delay) |
| 212 : prefix(prefix), qtype(qtype), delay(delay) { |
| 213 switch (result) { |
| 214 case FAIL: |
| 215 response_callback = |
| 216 base::Bind(&ReturnNetError, net::ERR_NAME_NOT_RESOLVED); |
| 217 break; |
| 218 case TIMEOUT: |
| 219 response_callback = base::Bind(&ReturnNetError, net::ERR_DNS_TIMED_OUT); |
| 220 break; |
| 221 case EMPTY: |
| 222 response_callback = base::Bind(&ReturnNetError, net::OK); |
| 223 break; |
| 224 case OK: |
| 225 response_callback = base::Bind(&WriteLoopbackRecordResponse, qtype); |
| 226 break; |
| 227 } |
| 228 CHECK(!response_callback.is_null()); |
| 229 } |
| 230 |
| 231 MockDnsClientRule::MockDnsClientRule(const std::string& prefix, |
| 232 uint16_t qtype, |
| 233 ResponseCallback response_callback, |
| 234 bool delay) |
| 235 : response_callback(response_callback), |
| 236 prefix(prefix), |
| 237 qtype(qtype), |
| 238 delay(delay) {} |
| 239 |
| 240 MockDnsClientRule::MockDnsClientRule(const MockDnsClientRule& o) |
| 241 : response_callback(o.response_callback), |
| 242 prefix(o.prefix), |
| 243 qtype(o.qtype), |
| 244 delay(o.delay) {} |
| 245 |
| 246 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& o) |
| 247 : response_callback(std::move(o.response_callback)), |
| 248 prefix(std::move(o.prefix)), |
| 249 qtype(o.qtype), |
| 250 delay(o.delay) {} |
| 251 |
| 252 MockDnsClientRule::~MockDnsClientRule() {} |
| 253 |
| 199 MockDnsClient::MockDnsClient(const DnsConfig& config, | 254 MockDnsClient::MockDnsClient(const DnsConfig& config, |
| 200 const MockDnsClientRuleList& rules) | 255 const MockDnsClientRuleList& rules) |
| 201 : config_(config), | 256 : config_(config), |
| 202 factory_(new MockTransactionFactory(rules)), | 257 factory_(new MockTransactionFactory(rules)), |
| 203 address_sorter_(new MockAddressSorter()) { | 258 address_sorter_(new MockAddressSorter()) { |
| 204 } | 259 } |
| 205 | 260 |
| 206 MockDnsClient::~MockDnsClient() {} | 261 MockDnsClient::~MockDnsClient() {} |
| 207 | 262 |
| 208 void MockDnsClient::SetConfig(const DnsConfig& config) { | 263 void MockDnsClient::SetConfig(const DnsConfig& config) { |
| (...skipping 10 matching lines...) Expand all Loading... |
| 219 | 274 |
| 220 AddressSorter* MockDnsClient::GetAddressSorter() { | 275 AddressSorter* MockDnsClient::GetAddressSorter() { |
| 221 return address_sorter_.get(); | 276 return address_sorter_.get(); |
| 222 } | 277 } |
| 223 | 278 |
| 224 void MockDnsClient::CompleteDelayedTransactions() { | 279 void MockDnsClient::CompleteDelayedTransactions() { |
| 225 factory_->CompleteDelayedTransactions(); | 280 factory_->CompleteDelayedTransactions(); |
| 226 } | 281 } |
| 227 | 282 |
| 228 } // namespace net | 283 } // namespace net |
| OLD | NEW |