OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "net/dns/dns_test_util.h" | 5 #include "net/dns/dns_test_util.h" |
6 | 6 |
7 #include <string> | 7 #include <string> |
8 | 8 |
9 #include "base/big_endian.h" | 9 #include "base/big_endian.h" |
10 #include "base/bind.h" | 10 #include "base/bind.h" |
| 11 #include "base/callback.h" |
11 #include "base/location.h" | 12 #include "base/location.h" |
12 #include "base/memory/weak_ptr.h" | 13 #include "base/memory/weak_ptr.h" |
13 #include "base/single_thread_task_runner.h" | 14 #include "base/single_thread_task_runner.h" |
14 #include "base/sys_byteorder.h" | 15 #include "base/sys_byteorder.h" |
15 #include "base/threading/thread_task_runner_handle.h" | 16 #include "base/threading/thread_task_runner_handle.h" |
16 #include "net/base/io_buffer.h" | 17 #include "net/base/io_buffer.h" |
17 #include "net/base/net_errors.h" | 18 #include "net/base/net_errors.h" |
18 #include "net/dns/address_sorter.h" | 19 #include "net/dns/address_sorter.h" |
19 #include "net/dns/dns_query.h" | 20 #include "net/dns/dns_query.h" |
20 #include "net/dns/dns_response.h" | 21 #include "net/dns/dns_response.h" |
(...skipping 15 matching lines...) Expand all Loading... |
36 }; | 37 }; |
37 | 38 |
38 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. | 39 // A DnsTransaction which uses MockDnsClientRuleList to determine the response. |
39 class MockTransaction : public DnsTransaction, | 40 class MockTransaction : public DnsTransaction, |
40 public base::SupportsWeakPtr<MockTransaction> { | 41 public base::SupportsWeakPtr<MockTransaction> { |
41 public: | 42 public: |
42 MockTransaction(const MockDnsClientRuleList& rules, | 43 MockTransaction(const MockDnsClientRuleList& rules, |
43 const std::string& hostname, | 44 const std::string& hostname, |
44 uint16_t qtype, | 45 uint16_t qtype, |
45 const DnsTransactionFactory::CallbackType& callback) | 46 const DnsTransactionFactory::CallbackType& callback) |
46 : result_(MockDnsClientRule::FAIL), | 47 : hostname_(hostname), |
47 hostname_(hostname), | |
48 qtype_(qtype), | 48 qtype_(qtype), |
49 callback_(callback), | 49 callback_(callback), |
50 started_(false), | 50 started_(false), |
51 delayed_(false) { | 51 delayed_(false) { |
52 // Find the relevant rule which matches |qtype| and prefix of |hostname|. | 52 // Find the relevant rule which matches |qtype| and prefix of |hostname|. |
53 for (size_t i = 0; i < rules.size(); ++i) { | 53 for (size_t i = 0; i < rules.size(); ++i) { |
54 const std::string& prefix = rules[i].prefix; | 54 const std::string& prefix = rules[i].prefix; |
55 if ((rules[i].qtype == qtype) && | 55 if ((rules[i].qtype == qtype) && |
56 (hostname.size() >= prefix.size()) && | 56 (hostname.size() >= prefix.size()) && |
57 (hostname.compare(0, prefix.size(), prefix) == 0)) { | 57 (hostname.compare(0, prefix.size(), prefix) == 0)) { |
58 result_ = rules[i].result; | 58 response_callback_ = rules[i].response_callback; |
59 delayed_ = rules[i].delay; | 59 delayed_ = rules[i].delay; |
60 break; | 60 break; |
61 } | 61 } |
62 } | 62 } |
63 } | 63 } |
64 | 64 |
65 const std::string& GetHostname() const override { return hostname_; } | 65 const std::string& GetHostname() const override { return hostname_; } |
66 | 66 |
67 uint16_t GetType() const override { return qtype_; } | 67 uint16_t GetType() const override { return qtype_; } |
68 | 68 |
(...skipping 10 matching lines...) Expand all Loading... |
79 void FinishDelayedTransaction() { | 79 void FinishDelayedTransaction() { |
80 EXPECT_TRUE(delayed_); | 80 EXPECT_TRUE(delayed_); |
81 delayed_ = false; | 81 delayed_ = false; |
82 Finish(); | 82 Finish(); |
83 } | 83 } |
84 | 84 |
85 bool delayed() const { return delayed_; } | 85 bool delayed() const { return delayed_; } |
86 | 86 |
87 private: | 87 private: |
88 void Finish() { | 88 void Finish() { |
89 switch (result_) { | 89 if (response_callback_.is_null()) { |
90 case MockDnsClientRule::EMPTY: | 90 callback_.Run(this, ERR_NAME_NOT_RESOLVED, nullptr); |
91 case MockDnsClientRule::OK: { | 91 return; |
92 std::string qname; | 92 } |
93 DNSDomainFromDot(hostname_, &qname); | |
94 DnsQuery query(0, qname, qtype_); | |
95 | 93 |
96 DnsResponse response; | 94 std::string qname; |
97 char* buffer = response.io_buffer()->data(); | 95 DNSDomainFromDot(hostname_, &qname); |
98 int nbytes = query.io_buffer()->size(); | 96 DnsQuery query(0, qname, qtype_); |
99 memcpy(buffer, query.io_buffer()->data(), nbytes); | |
100 dns_protocol::Header* header = | |
101 reinterpret_cast<dns_protocol::Header*>(buffer); | |
102 header->flags |= dns_protocol::kFlagResponse; | |
103 | 97 |
104 if (MockDnsClientRule::OK == result_) { | 98 DnsResponse response; |
105 const uint16_t kPointerToQueryName = | 99 IOBufferWithSize* buffer = response.io_buffer(); |
106 static_cast<uint16_t>(0xc000 | sizeof(*header)); | 100 int query_size = query.io_buffer()->size(); |
| 101 CHECK_GE(buffer->size(), query_size); |
| 102 memcpy(buffer->data(), query.io_buffer()->data(), query_size); |
| 103 dns_protocol::Header* header = |
| 104 reinterpret_cast<dns_protocol::Header*>(buffer->data()); |
| 105 header->flags |= dns_protocol::kFlagResponse; |
107 | 106 |
108 const uint32_t kTTL = 86400; // One day. | 107 base::BigEndianWriter answer_writer(buffer->data() + query_size, |
109 | 108 buffer->size() - query_size); |
110 // Size of RDATA which is a IPv4 or IPv6 address. | 109 int net_error = response_callback_.Run(header, &answer_writer); |
111 size_t rdata_size = qtype_ == dns_protocol::kTypeA | 110 if (net_error == OK) { |
112 ? IPAddress::kIPv4AddressSize | 111 int nbytes = answer_writer.ptr() - buffer->data(); |
113 : IPAddress::kIPv6AddressSize; | 112 EXPECT_TRUE(response.InitParse(nbytes, query)); |
114 | 113 callback_.Run(this, OK, &response); |
115 // 12 is the sum of sizes of the compressed name reference, TYPE, | 114 } else { |
116 // CLASS, TTL and RDLENGTH. | 115 callback_.Run(this, net_error, nullptr); |
117 size_t answer_size = 12 + rdata_size; | |
118 | |
119 // Write answer with loopback IP address. | |
120 header->ancount = base::HostToNet16(1); | |
121 base::BigEndianWriter writer(buffer + nbytes, answer_size); | |
122 writer.WriteU16(kPointerToQueryName); | |
123 writer.WriteU16(qtype_); | |
124 writer.WriteU16(dns_protocol::kClassIN); | |
125 writer.WriteU32(kTTL); | |
126 writer.WriteU16(static_cast<uint16_t>(rdata_size)); | |
127 if (qtype_ == dns_protocol::kTypeA) { | |
128 char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; | |
129 writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); | |
130 } else { | |
131 char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, | |
132 0, 0, 0, 0, 0, 0, 0, 1 }; | |
133 writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); | |
134 } | |
135 nbytes += answer_size; | |
136 } | |
137 EXPECT_TRUE(response.InitParse(nbytes, query)); | |
138 callback_.Run(this, OK, &response); | |
139 } break; | |
140 case MockDnsClientRule::FAIL: | |
141 callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); | |
142 break; | |
143 case MockDnsClientRule::TIMEOUT: | |
144 callback_.Run(this, ERR_DNS_TIMED_OUT, NULL); | |
145 break; | |
146 default: | |
147 NOTREACHED(); | |
148 break; | |
149 } | 116 } |
150 } | 117 } |
151 | 118 |
152 MockDnsClientRule::Result result_; | 119 MockDnsClientRule::ResponseCallback response_callback_; |
153 const std::string hostname_; | 120 const std::string hostname_; |
154 const uint16_t qtype_; | 121 const uint16_t qtype_; |
155 DnsTransactionFactory::CallbackType callback_; | 122 DnsTransactionFactory::CallbackType callback_; |
156 bool started_; | 123 bool started_; |
157 bool delayed_; | 124 bool delayed_; |
158 }; | 125 }; |
159 | 126 |
| 127 // Simply returns the |net_error| argument. |
| 128 // Useful as a simple callback that does nothing but reports an error. |
| 129 int ReturnNetError(int net_error, |
| 130 dns_protocol::Header* response_header, |
| 131 base::BigEndianWriter* answer_writer) { |
| 132 CHECK_LE(net_error, 0); |
| 133 return net_error; |
| 134 } |
| 135 |
| 136 // Writes a |qtype| record for the loopback address using |answer_writer|. |
| 137 // |qtype| must be |dns_protocol::kTypeA| or |dns_protocol::kTypeAAAA|. |
| 138 // Returns net::OK if successful. |
| 139 int WriteLoopbackRecordResponse(uint16_t qtype, |
| 140 dns_protocol::Header* response_header, |
| 141 base::BigEndianWriter* answer_writer) { |
| 142 const uint16_t kPointerToQueryName = |
| 143 static_cast<uint16_t>(0xc000 | sizeof(*response_header)); |
| 144 |
| 145 const uint32_t kTTL = 86400; // One day. |
| 146 |
| 147 // Write answer with loopback IP address. |
| 148 response_header->ancount = |
| 149 base::HostToNet16(base::NetToHost16(response_header->ancount) + 1); |
| 150 CHECK(answer_writer->WriteU16(kPointerToQueryName)); |
| 151 CHECK(answer_writer->WriteU16(qtype)); |
| 152 CHECK(answer_writer->WriteU16(dns_protocol::kClassIN)); |
| 153 CHECK(answer_writer->WriteU32(kTTL)); |
| 154 if (qtype == dns_protocol::kTypeA) { |
| 155 char kIPv4Loopback[] = {0x7f, 0, 0, 1}; |
| 156 CHECK(answer_writer->WriteU16(IPAddress::kIPv4AddressSize)); |
| 157 CHECK(answer_writer->WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback))); |
| 158 } else if (qtype == dns_protocol::kTypeAAAA) { |
| 159 char kIPv6Loopback[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; |
| 160 CHECK(answer_writer->WriteU16(IPAddress::kIPv6AddressSize)); |
| 161 CHECK(answer_writer->WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback))); |
| 162 } else { |
| 163 NOTREACHED(); |
| 164 } |
| 165 |
| 166 return OK; |
| 167 } |
| 168 |
160 } // namespace | 169 } // namespace |
161 | 170 |
162 // A DnsTransactionFactory which creates MockTransaction. | 171 // A DnsTransactionFactory which creates MockTransaction. |
163 class MockTransactionFactory : public DnsTransactionFactory { | 172 class MockTransactionFactory : public DnsTransactionFactory { |
164 public: | 173 public: |
165 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) | 174 explicit MockTransactionFactory(const MockDnsClientRuleList& rules) |
166 : rules_(rules) {} | 175 : rules_(rules) {} |
167 | 176 |
168 ~MockTransactionFactory() override {} | 177 ~MockTransactionFactory() override {} |
169 | 178 |
(...skipping 19 matching lines...) Expand all Loading... |
189 } | 198 } |
190 } | 199 } |
191 | 200 |
192 private: | 201 private: |
193 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; | 202 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList; |
194 | 203 |
195 MockDnsClientRuleList rules_; | 204 MockDnsClientRuleList rules_; |
196 DelayedTransactionList delayed_transactions_; | 205 DelayedTransactionList delayed_transactions_; |
197 }; | 206 }; |
198 | 207 |
| 208 MockDnsClientRule::MockDnsClientRule(const std::string& prefix, |
| 209 uint16_t qtype, |
| 210 Result result, |
| 211 bool delay) |
| 212 : prefix(prefix), qtype(qtype), delay(delay) { |
| 213 switch (result) { |
| 214 case FAIL: |
| 215 response_callback = |
| 216 base::Bind(&ReturnNetError, net::ERR_NAME_NOT_RESOLVED); |
| 217 break; |
| 218 case TIMEOUT: |
| 219 response_callback = base::Bind(&ReturnNetError, net::ERR_DNS_TIMED_OUT); |
| 220 break; |
| 221 case EMPTY: |
| 222 response_callback = base::Bind(&ReturnNetError, net::OK); |
| 223 break; |
| 224 case OK: |
| 225 response_callback = base::Bind(&WriteLoopbackRecordResponse, qtype); |
| 226 break; |
| 227 } |
| 228 CHECK(!response_callback.is_null()); |
| 229 } |
| 230 |
| 231 MockDnsClientRule::MockDnsClientRule(const std::string& prefix, |
| 232 uint16_t qtype, |
| 233 ResponseCallback response_callback, |
| 234 bool delay) |
| 235 : response_callback(response_callback), |
| 236 prefix(prefix), |
| 237 qtype(qtype), |
| 238 delay(delay) {} |
| 239 |
| 240 MockDnsClientRule::MockDnsClientRule(const MockDnsClientRule& o) |
| 241 : response_callback(o.response_callback), |
| 242 prefix(o.prefix), |
| 243 qtype(o.qtype), |
| 244 delay(o.delay) {} |
| 245 |
| 246 MockDnsClientRule::MockDnsClientRule(MockDnsClientRule&& o) |
| 247 : response_callback(std::move(o.response_callback)), |
| 248 prefix(std::move(o.prefix)), |
| 249 qtype(o.qtype), |
| 250 delay(o.delay) {} |
| 251 |
| 252 MockDnsClientRule::~MockDnsClientRule() {} |
| 253 |
199 MockDnsClient::MockDnsClient(const DnsConfig& config, | 254 MockDnsClient::MockDnsClient(const DnsConfig& config, |
200 const MockDnsClientRuleList& rules) | 255 const MockDnsClientRuleList& rules) |
201 : config_(config), | 256 : config_(config), |
202 factory_(new MockTransactionFactory(rules)), | 257 factory_(new MockTransactionFactory(rules)), |
203 address_sorter_(new MockAddressSorter()) { | 258 address_sorter_(new MockAddressSorter()) { |
204 } | 259 } |
205 | 260 |
206 MockDnsClient::~MockDnsClient() {} | 261 MockDnsClient::~MockDnsClient() {} |
207 | 262 |
208 void MockDnsClient::SetConfig(const DnsConfig& config) { | 263 void MockDnsClient::SetConfig(const DnsConfig& config) { |
(...skipping 10 matching lines...) Expand all Loading... |
219 | 274 |
220 AddressSorter* MockDnsClient::GetAddressSorter() { | 275 AddressSorter* MockDnsClient::GetAddressSorter() { |
221 return address_sorter_.get(); | 276 return address_sorter_.get(); |
222 } | 277 } |
223 | 278 |
224 void MockDnsClient::CompleteDelayedTransactions() { | 279 void MockDnsClient::CompleteDelayedTransactions() { |
225 factory_->CompleteDelayedTransactions(); | 280 factory_->CompleteDelayedTransactions(); |
226 } | 281 } |
227 | 282 |
228 } // namespace net | 283 } // namespace net |
OLD | NEW |