| OLD | NEW |
| (Empty) |
| 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | |
| 2 // Use of this source code is governed by a BSD-style license that can be | |
| 3 // found in the LICENSE file. | |
| 4 | |
| 5 #include "net/dns/dns_test_util.h" | |
| 6 | |
| 7 #include <string> | |
| 8 | |
| 9 #include "base/big_endian.h" | |
| 10 #include "base/bind.h" | |
| 11 #include "base/memory/weak_ptr.h" | |
| 12 #include "base/message_loop/message_loop.h" | |
| 13 #include "base/sys_byteorder.h" | |
| 14 #include "net/base/dns_util.h" | |
| 15 #include "net/base/io_buffer.h" | |
| 16 #include "net/base/net_errors.h" | |
| 17 #include "net/dns/address_sorter.h" | |
| 18 #include "net/dns/dns_query.h" | |
| 19 #include "net/dns/dns_response.h" | |
| 20 #include "net/dns/dns_transaction.h" | |
| 21 #include "testing/gtest/include/gtest/gtest.h" | |
| 22 | |
| 23 namespace net { | |
| 24 namespace { | |
| 25 | |
| 26 class MockAddressSorter : public AddressSorter { | |
| 27 public: | |
| 28 ~MockAddressSorter() override {} | |
| 29 void Sort(const AddressList& list, | |
| 30 const CallbackType& callback) const override { | |
| 31 // Do nothing. | |
| 32 callback.Run(true, list); | |
| 33 } | |
| 34 }; | |
| 35 | |
| 36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. | |
| 37 class MockTransaction : public DnsTransaction, | |
| 38 public base::SupportsWeakPtr<MockTransaction> { | |
| 39 public: | |
| 40 MockTransaction(const MockDnsClientRuleList& rules, | |
| 41 const std::string& hostname, | |
| 42 uint16 qtype, | |
| 43 const DnsTransactionFactory::CallbackType& callback) | |
| 44 : result_(MockDnsClientRule::FAIL), | |
| 45 hostname_(hostname), | |
| 46 qtype_(qtype), | |
| 47 callback_(callback), | |
| 48 started_(false), | |
| 49 delayed_(false) { | |
| 50 // Find the relevant rule which matches |qtype| and prefix of |hostname|. | |
| 51 for (size_t i = 0; i < rules.size(); ++i) { | |
| 52 const std::string& prefix = rules[i].prefix; | |
| 53 if ((rules[i].qtype == qtype) && | |
| 54 (hostname.size() >= prefix.size()) && | |
| 55 (hostname.compare(0, prefix.size(), prefix) == 0)) { | |
| 56 result_ = rules[i].result; | |
| 57 delayed_ = rules[i].delay; | |
| 58 break; | |
| 59 } | |
| 60 } | |
| 61 } | |
| 62 | |
| 63 virtual const std::string& GetHostname() const override { | |
| 64 return hostname_; | |
| 65 } | |
| 66 | |
| 67 virtual uint16 GetType() const override { | |
| 68 return qtype_; | |
| 69 } | |
| 70 | |
| 71 virtual void Start() override { | |
| 72 EXPECT_FALSE(started_); | |
| 73 started_ = true; | |
| 74 if (delayed_) | |
| 75 return; | |
| 76 // Using WeakPtr to cleanly cancel when transaction is destroyed. | |
| 77 base::MessageLoop::current()->PostTask( | |
| 78 FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr())); | |
| 79 } | |
| 80 | |
| 81 void FinishDelayedTransaction() { | |
| 82 EXPECT_TRUE(delayed_); | |
| 83 delayed_ = false; | |
| 84 Finish(); | |
| 85 } | |
| 86 | |
| 87 bool delayed() const { return delayed_; } | |
| 88 | |
| 89 private: | |
| 90 void Finish() { | |
| 91 switch (result_) { | |
| 92 case MockDnsClientRule::EMPTY: | |
| 93 case MockDnsClientRule::OK: { | |
| 94 std::string qname; | |
| 95 DNSDomainFromDot(hostname_, &qname); | |
| 96 DnsQuery query(0, qname, qtype_); | |
| 97 | |
| 98 DnsResponse response; | |
| 99 char* buffer = response.io_buffer()->data(); | |
| 100 int nbytes = query.io_buffer()->size(); | |
| 101 memcpy(buffer, query.io_buffer()->data(), nbytes); | |
| 102 dns_protocol::Header* header = | |
| 103 reinterpret_cast<dns_protocol::Header*>(buffer); | |
| 104 header->flags |= dns_protocol::kFlagResponse; | |
| 105 | |
| 106 if (MockDnsClientRule::OK == result_) { | |
| 107 const uint16 kPointerToQueryName = | |
| 108 static_cast<uint16>(0xc000 | sizeof(*header)); | |
| 109 | |
| 110 const uint32 kTTL = 86400; // One day. | |
| 111 | |
| 112 // Size of RDATA which is a IPv4 or IPv6 address. | |
| 113 size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? | |
| 114 net::kIPv4AddressSize : net::kIPv6AddressSize; | |
| 115 | |
| 116 // 12 is the sum of sizes of the compressed name reference, TYPE, | |
| 117 // CLASS, TTL and RDLENGTH. | |
| 118 size_t answer_size = 12 + rdata_size; | |
| 119 | |
| 120 // Write answer with loopback IP address. | |
| 121 header->ancount = base::HostToNet16(1); | |
| 122 base::BigEndianWriter writer(buffer + nbytes, answer_size); | |
| 123 writer.WriteU16(kPointerToQueryName); | |
| 124 writer.WriteU16(qtype_); | |
| 125 writer.WriteU16(net::dns_protocol::kClassIN); | |
| 126 writer.WriteU32(kTTL); | |
| 127 writer.WriteU16(rdata_size); | |
| 128 if (qtype_ == net::dns_protocol::kTypeA) { | |
| 129 char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; | |
| 130 writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); | |
| 131 } else { | |
| 132 char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, | |
| 133 0, 0, 0, 0, 0, 0, 0, 1 }; | |
| 134 writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); | |
| 135 } | |
| 136 nbytes += answer_size; | |
| 137 } | |
| 138 EXPECT_TRUE(response.InitParse(nbytes, query)); | |
| 139 callback_.Run(this, OK, &response); | |
| 140 } break; | |
| 141 case MockDnsClientRule::FAIL: | |
| 142 callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); | |
| 143 break; | |
| 144 case MockDnsClientRule::TIMEOUT: | |
| 145 callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); | |
| 146 break; | |
| 147 default: | |
| 148 NOTREACHED(); | |
| 149 break; | |
| 150 } | |
| 151 } | |
| 152 | |
| 153 MockDnsClientRule::Result result_; | |
| 154 const std::string hostname_; | |
| 155 const uint16 qtype_; | |
| 156 DnsTransactionFactory::CallbackType callback_; | |
| 157 bool started_; | |
| 158 bool delayed_; | |
| 159 }; | |
| 160 | |
| 161 } // namespace | |
| 162 | |
| 163 // A DnsTransactionFactory which creates MockTransaction. | |
| 164 class MockTransactionFactory : public DnsTransactionFactory { | |
| 165 public: | |
| 166 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) | |
| 167 : rules_(rules) {} | |
| 168 | |
| 169 ~MockTransactionFactory() override {} | |
| 170 | |
| 171 scoped_ptr<DnsTransaction> CreateTransaction( | |
| 172 const std::string& hostname, | |
| 173 uint16 qtype, | |
| 174 const DnsTransactionFactory::CallbackType& callback, | |
| 175 const BoundNetLog&) override { | |
| 176 MockTransaction* transaction = | |
| 177 new MockTransaction(rules_, hostname, qtype, callback); | |
| 178 if (transaction->delayed()) | |
| 179 delayed_transactions_.push_back(transaction->AsWeakPtr()); | |
| 180 return scoped_ptr<DnsTransaction>(transaction); | |
| 181 } | |
| 182 | |
| 183 void CompleteDelayedTransactions() { | |
| 184 DelayedTransactionList old_delayed_transactions; | |
| 185 old_delayed_transactions.swap(delayed_transactions_); | |
| 186 for (DelayedTransactionList::iterator it = old_delayed_transactions.begin(); | |
| 187 it != old_delayed_transactions.end(); ++it) { | |
| 188 if (it->get()) | |
| 189 (*it)->FinishDelayedTransaction(); | |
| 190 } | |
| 191 } | |
| 192 | |
| 193 private: | |
| 194 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; | |
| 195 | |
| 196 MockDnsClientRuleList rules_; | |
| 197 DelayedTransactionList delayed_transactions_; | |
| 198 }; | |
| 199 | |
| 200 MockDnsClient::MockDnsClient(const DnsConfig& config, | |
| 201 const MockDnsClientRuleList& rules) | |
| 202 : config_(config), | |
| 203 factory_(new MockTransactionFactory(rules)), | |
| 204 address_sorter_(new MockAddressSorter()) { | |
| 205 } | |
| 206 | |
| 207 MockDnsClient::~MockDnsClient() {} | |
| 208 | |
| 209 void MockDnsClient::SetConfig(const DnsConfig& config) { | |
| 210 config_ = config; | |
| 211 } | |
| 212 | |
| 213 const DnsConfig* MockDnsClient::GetConfig() const { | |
| 214 return config_.IsValid() ? &config_ : NULL; | |
| 215 } | |
| 216 | |
| 217 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() { | |
| 218 return config_.IsValid() ? factory_.get() : NULL; | |
| 219 } | |
| 220 | |
| 221 AddressSorter* MockDnsClient::GetAddressSorter() { | |
| 222 return address_sorter_.get(); | |
| 223 } | |
| 224 | |
| 225 void MockDnsClient::CompleteDelayedTransactions() { | |
| 226 factory_->CompleteDelayedTransactions(); | |
| 227 } | |
| 228 | |
| 229 } // namespace net | |
| OLD | NEW |