Index: net/dns/dns_test_util.cc |
diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc |
index 051f595e5a87df15c5b3806970690b9f06f03bdf..73d208c61519cd9d5f25f91e39cab01a564e9bb7 100644 |
--- a/net/dns/dns_test_util.cc |
+++ b/net/dns/dns_test_util.cc |
@@ -14,6 +14,7 @@ |
#include "net/base/dns_util.h" |
#include "net/base/io_buffer.h" |
#include "net/base/net_errors.h" |
+#include "net/dns/address_sorter.h" |
#include "net/dns/dns_client.h" |
#include "net/dns/dns_config_service.h" |
#include "net/dns/dns_protocol.h" |
@@ -25,19 +26,29 @@ |
namespace net { |
namespace { |
-// A DnsTransaction which responds with loopback to all queries starting with |
-// "ok", fails synchronously on all queries starting with "er", and NXDOMAIN to |
-// all others. |
+// A DnsTransaction which uses MockDnsClientRuleList to determine the response. |
class MockTransaction : public DnsTransaction, |
public base::SupportsWeakPtr<MockTransaction> { |
public: |
- MockTransaction(const std::string& hostname, |
+ MockTransaction(const MockDnsClientRuleList& rules, |
+ const std::string& hostname, |
uint16 qtype, |
const DnsTransactionFactory::CallbackType& callback) |
- : hostname_(hostname), |
+ : result_(MockDnsClientRule::FAIL_SYNC), |
+ hostname_(hostname), |
qtype_(qtype), |
callback_(callback), |
started_(false) { |
+ // Find the relevant rule which matches |qtype| and prefix of |hostname|. |
+ for (size_t i = 0; i < rules.size(); ++i) { |
+ const std::string& prefix = rules[i].prefix; |
+ if ((rules[i].qtype == qtype) && |
+ (hostname.size() >= prefix.size()) && |
+ (hostname.compare(0, prefix.size(), prefix) == 0)) { |
+ result_ = rules[i].result; |
+ break; |
+ } |
+ } |
} |
virtual const std::string& GetHostname() const OVERRIDE { |
@@ -51,7 +62,7 @@ class MockTransaction : public DnsTransaction, |
virtual int Start() OVERRIDE { |
EXPECT_FALSE(started_); |
started_ = true; |
- if (hostname_.substr(0, 2) == "er") |
+ if (MockDnsClientRule::FAIL_SYNC == result_) |
return ERR_NAME_NOT_RESOLVED; |
// Using WeakPtr to cleanly cancel when transaction is destroyed. |
MessageLoop::current()->PostTask( |
@@ -62,54 +73,66 @@ class MockTransaction : public DnsTransaction, |
private: |
void Finish() { |
- if (hostname_.substr(0, 2) == "ok") { |
- std::string qname; |
- DNSDomainFromDot(hostname_, &qname); |
- DnsQuery query(0, qname, qtype_); |
- |
- DnsResponse response; |
- char* buffer = response.io_buffer()->data(); |
- int nbytes = query.io_buffer()->size(); |
- memcpy(buffer, query.io_buffer()->data(), nbytes); |
- |
- const uint16 kPointerToQueryName = |
- static_cast<uint16>(0xc000 | sizeof(net::dns_protocol::Header)); |
- |
- const uint32 kTTL = 86400; // One day. |
- |
- // Size of RDATA which is a IPv4 or IPv6 address. |
- size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? |
- net::kIPv4AddressSize : net::kIPv6AddressSize; |
- |
- // 12 is the sum of sizes of the compressed name reference, TYPE, |
- // CLASS, TTL and RDLENGTH. |
- size_t answer_size = 12 + rdata_size; |
- |
- // Write answer with loopback IP address. |
- reinterpret_cast<dns_protocol::Header*>(buffer)->ancount = |
- base::HostToNet16(1); |
- BigEndianWriter writer(buffer + nbytes, answer_size); |
- writer.WriteU16(kPointerToQueryName); |
- writer.WriteU16(qtype_); |
- writer.WriteU16(net::dns_protocol::kClassIN); |
- writer.WriteU32(kTTL); |
- writer.WriteU16(rdata_size); |
- if (qtype_ == net::dns_protocol::kTypeA) { |
- char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; |
- writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); |
- } else { |
- char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, |
- 0, 0, 0, 0, 0, 0, 0, 1 }; |
- writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); |
- } |
- |
- EXPECT_TRUE(response.InitParse(nbytes + answer_size, query)); |
- callback_.Run(this, OK, &response); |
- } else { |
- callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); |
+ switch (result_) { |
+ case MockDnsClientRule::EMPTY: |
+ case MockDnsClientRule::OK: { |
+ std::string qname; |
+ DNSDomainFromDot(hostname_, &qname); |
+ DnsQuery query(0, qname, qtype_); |
+ |
+ DnsResponse response; |
+ char* buffer = response.io_buffer()->data(); |
+ int nbytes = query.io_buffer()->size(); |
+ memcpy(buffer, query.io_buffer()->data(), nbytes); |
+ dns_protocol::Header* header = |
+ reinterpret_cast<dns_protocol::Header*>(buffer); |
+ header->flags |= dns_protocol::kFlagResponse; |
+ |
+ if (MockDnsClientRule::OK == result_) { |
+ const uint16 kPointerToQueryName = |
+ static_cast<uint16>(0xc000 | sizeof(*header)); |
+ |
+ const uint32 kTTL = 86400; // One day. |
+ |
+ // Size of RDATA which is a IPv4 or IPv6 address. |
+ size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ? |
+ net::kIPv4AddressSize : net::kIPv6AddressSize; |
+ |
+ // 12 is the sum of sizes of the compressed name reference, TYPE, |
+ // CLASS, TTL and RDLENGTH. |
+ size_t answer_size = 12 + rdata_size; |
+ |
+ // Write answer with loopback IP address. |
+ header->ancount = base::HostToNet16(1); |
+ BigEndianWriter writer(buffer + nbytes, answer_size); |
+ writer.WriteU16(kPointerToQueryName); |
+ writer.WriteU16(qtype_); |
+ writer.WriteU16(net::dns_protocol::kClassIN); |
+ writer.WriteU32(kTTL); |
+ writer.WriteU16(rdata_size); |
+ if (qtype_ == net::dns_protocol::kTypeA) { |
+ char kIPv4Loopback[] = { 0x7f, 0, 0, 1 }; |
+ writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback)); |
+ } else { |
+ char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0, |
+ 0, 0, 0, 0, 0, 0, 0, 1 }; |
+ writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback)); |
+ } |
+ nbytes += answer_size; |
+ } |
+ EXPECT_TRUE(response.InitParse(nbytes, query)); |
+ callback_.Run(this, OK, &response); |
+ } break; |
+ case MockDnsClientRule::FAIL_ASYNC: |
+ callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL); |
+ break; |
+ default: |
+ NOTREACHED(); |
+ break; |
} |
} |
+ MockDnsClientRule::Result result_; |
const std::string hostname_; |
const uint16 qtype_; |
DnsTransactionFactory::CallbackType callback_; |
@@ -120,7 +143,8 @@ class MockTransaction : public DnsTransaction, |
// A DnsTransactionFactory which creates MockTransaction. |
class MockTransactionFactory : public DnsTransactionFactory { |
public: |
- MockTransactionFactory() {} |
+ explicit MockTransactionFactory(const MockDnsClientRuleList& rules) |
+ : rules_(rules) {} |
virtual ~MockTransactionFactory() {} |
virtual scoped_ptr<DnsTransaction> CreateTransaction( |
@@ -129,14 +153,29 @@ class MockTransactionFactory : public DnsTransactionFactory { |
const DnsTransactionFactory::CallbackType& callback, |
const BoundNetLog&) OVERRIDE { |
return scoped_ptr<DnsTransaction>( |
- new MockTransaction(hostname, qtype, callback)); |
+ new MockTransaction(rules_, hostname, qtype, callback)); |
+ } |
+ |
+ private: |
+ MockDnsClientRuleList rules_; |
+}; |
+ |
+class MockAddressSorter : public AddressSorter { |
+ public: |
+ virtual ~MockAddressSorter() {} |
+ virtual void Sort(const AddressList& list, |
+ const CallbackType& callback) const OVERRIDE { |
+ // Do nothing. |
+ callback.Run(true, list); |
} |
}; |
// MockDnsClient provides MockTransactionFactory. |
class MockDnsClient : public DnsClient { |
public: |
- explicit MockDnsClient(const DnsConfig& config) : config_(config) {} |
+ MockDnsClient(const DnsConfig& config, |
+ const MockDnsClientRuleList& rules) |
+ : config_(config), factory_(rules) {} |
virtual ~MockDnsClient() {} |
virtual void SetConfig(const DnsConfig& config) OVERRIDE { |
@@ -151,16 +190,22 @@ class MockDnsClient : public DnsClient { |
return config_.IsValid() ? &factory_ : NULL; |
} |
+ virtual AddressSorter* GetAddressSorter() OVERRIDE { |
+ return &address_sorter_; |
+ } |
+ |
private: |
DnsConfig config_; |
MockTransactionFactory factory_; |
+ MockAddressSorter address_sorter_; |
}; |
} // namespace |
// static |
-scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config) { |
- return scoped_ptr<DnsClient>(new MockDnsClient(config)); |
+scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config, |
+ const MockDnsClientRuleList& rules) { |
+ return scoped_ptr<DnsClient>(new MockDnsClient(config, rules)); |
} |
MockDnsConfigService::~MockDnsConfigService() { |