| Index: net/dns/dns_test_util.cc
|
| diff --git a/net/dns/dns_test_util.cc b/net/dns/dns_test_util.cc
|
| index 051f595e5a87df15c5b3806970690b9f06f03bdf..73d208c61519cd9d5f25f91e39cab01a564e9bb7 100644
|
| --- a/net/dns/dns_test_util.cc
|
| +++ b/net/dns/dns_test_util.cc
|
| @@ -14,6 +14,7 @@
|
| #include "net/base/dns_util.h"
|
| #include "net/base/io_buffer.h"
|
| #include "net/base/net_errors.h"
|
| +#include "net/dns/address_sorter.h"
|
| #include "net/dns/dns_client.h"
|
| #include "net/dns/dns_config_service.h"
|
| #include "net/dns/dns_protocol.h"
|
| @@ -25,19 +26,29 @@
|
| namespace net {
|
| namespace {
|
|
|
| -// A DnsTransaction which responds with loopback to all queries starting with
|
| -// "ok", fails synchronously on all queries starting with "er", and NXDOMAIN to
|
| -// all others.
|
| +// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
|
| class MockTransaction : public DnsTransaction,
|
| public base::SupportsWeakPtr<MockTransaction> {
|
| public:
|
| - MockTransaction(const std::string& hostname,
|
| + MockTransaction(const MockDnsClientRuleList& rules,
|
| + const std::string& hostname,
|
| uint16 qtype,
|
| const DnsTransactionFactory::CallbackType& callback)
|
| - : hostname_(hostname),
|
| + : result_(MockDnsClientRule::FAIL_SYNC),
|
| + hostname_(hostname),
|
| qtype_(qtype),
|
| callback_(callback),
|
| started_(false) {
|
| + // Find the relevant rule which matches |qtype| and prefix of |hostname|.
|
| + for (size_t i = 0; i < rules.size(); ++i) {
|
| + const std::string& prefix = rules[i].prefix;
|
| + if ((rules[i].qtype == qtype) &&
|
| + (hostname.size() >= prefix.size()) &&
|
| + (hostname.compare(0, prefix.size(), prefix) == 0)) {
|
| + result_ = rules[i].result;
|
| + break;
|
| + }
|
| + }
|
| }
|
|
|
| virtual const std::string& GetHostname() const OVERRIDE {
|
| @@ -51,7 +62,7 @@ class MockTransaction : public DnsTransaction,
|
| virtual int Start() OVERRIDE {
|
| EXPECT_FALSE(started_);
|
| started_ = true;
|
| - if (hostname_.substr(0, 2) == "er")
|
| + if (MockDnsClientRule::FAIL_SYNC == result_)
|
| return ERR_NAME_NOT_RESOLVED;
|
| // Using WeakPtr to cleanly cancel when transaction is destroyed.
|
| MessageLoop::current()->PostTask(
|
| @@ -62,54 +73,66 @@ class MockTransaction : public DnsTransaction,
|
|
|
| private:
|
| void Finish() {
|
| - if (hostname_.substr(0, 2) == "ok") {
|
| - std::string qname;
|
| - DNSDomainFromDot(hostname_, &qname);
|
| - DnsQuery query(0, qname, qtype_);
|
| -
|
| - DnsResponse response;
|
| - char* buffer = response.io_buffer()->data();
|
| - int nbytes = query.io_buffer()->size();
|
| - memcpy(buffer, query.io_buffer()->data(), nbytes);
|
| -
|
| - const uint16 kPointerToQueryName =
|
| - static_cast<uint16>(0xc000 | sizeof(net::dns_protocol::Header));
|
| -
|
| - const uint32 kTTL = 86400; // One day.
|
| -
|
| - // Size of RDATA which is a IPv4 or IPv6 address.
|
| - size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
|
| - net::kIPv4AddressSize : net::kIPv6AddressSize;
|
| -
|
| - // 12 is the sum of sizes of the compressed name reference, TYPE,
|
| - // CLASS, TTL and RDLENGTH.
|
| - size_t answer_size = 12 + rdata_size;
|
| -
|
| - // Write answer with loopback IP address.
|
| - reinterpret_cast<dns_protocol::Header*>(buffer)->ancount =
|
| - base::HostToNet16(1);
|
| - BigEndianWriter writer(buffer + nbytes, answer_size);
|
| - writer.WriteU16(kPointerToQueryName);
|
| - writer.WriteU16(qtype_);
|
| - writer.WriteU16(net::dns_protocol::kClassIN);
|
| - writer.WriteU32(kTTL);
|
| - writer.WriteU16(rdata_size);
|
| - if (qtype_ == net::dns_protocol::kTypeA) {
|
| - char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
|
| - writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
|
| - } else {
|
| - char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
|
| - 0, 0, 0, 0, 0, 0, 0, 1 };
|
| - writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
|
| - }
|
| -
|
| - EXPECT_TRUE(response.InitParse(nbytes + answer_size, query));
|
| - callback_.Run(this, OK, &response);
|
| - } else {
|
| - callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
|
| + switch (result_) {
|
| + case MockDnsClientRule::EMPTY:
|
| + case MockDnsClientRule::OK: {
|
| + std::string qname;
|
| + DNSDomainFromDot(hostname_, &qname);
|
| + DnsQuery query(0, qname, qtype_);
|
| +
|
| + DnsResponse response;
|
| + char* buffer = response.io_buffer()->data();
|
| + int nbytes = query.io_buffer()->size();
|
| + memcpy(buffer, query.io_buffer()->data(), nbytes);
|
| + dns_protocol::Header* header =
|
| + reinterpret_cast<dns_protocol::Header*>(buffer);
|
| + header->flags |= dns_protocol::kFlagResponse;
|
| +
|
| + if (MockDnsClientRule::OK == result_) {
|
| + const uint16 kPointerToQueryName =
|
| + static_cast<uint16>(0xc000 | sizeof(*header));
|
| +
|
| + const uint32 kTTL = 86400; // One day.
|
| +
|
| + // Size of RDATA which is a IPv4 or IPv6 address.
|
| + size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
|
| + net::kIPv4AddressSize : net::kIPv6AddressSize;
|
| +
|
| + // 12 is the sum of sizes of the compressed name reference, TYPE,
|
| + // CLASS, TTL and RDLENGTH.
|
| + size_t answer_size = 12 + rdata_size;
|
| +
|
| + // Write answer with loopback IP address.
|
| + header->ancount = base::HostToNet16(1);
|
| + BigEndianWriter writer(buffer + nbytes, answer_size);
|
| + writer.WriteU16(kPointerToQueryName);
|
| + writer.WriteU16(qtype_);
|
| + writer.WriteU16(net::dns_protocol::kClassIN);
|
| + writer.WriteU32(kTTL);
|
| + writer.WriteU16(rdata_size);
|
| + if (qtype_ == net::dns_protocol::kTypeA) {
|
| + char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
|
| + writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
|
| + } else {
|
| + char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
|
| + 0, 0, 0, 0, 0, 0, 0, 1 };
|
| + writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
|
| + }
|
| + nbytes += answer_size;
|
| + }
|
| + EXPECT_TRUE(response.InitParse(nbytes, query));
|
| + callback_.Run(this, OK, &response);
|
| + } break;
|
| + case MockDnsClientRule::FAIL_ASYNC:
|
| + callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
|
| + break;
|
| + default:
|
| + NOTREACHED();
|
| + break;
|
| }
|
| }
|
|
|
| + MockDnsClientRule::Result result_;
|
| const std::string hostname_;
|
| const uint16 qtype_;
|
| DnsTransactionFactory::CallbackType callback_;
|
| @@ -120,7 +143,8 @@ class MockTransaction : public DnsTransaction,
|
| // A DnsTransactionFactory which creates MockTransaction.
|
| class MockTransactionFactory : public DnsTransactionFactory {
|
| public:
|
| - MockTransactionFactory() {}
|
| + explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
|
| + : rules_(rules) {}
|
| virtual ~MockTransactionFactory() {}
|
|
|
| virtual scoped_ptr<DnsTransaction> CreateTransaction(
|
| @@ -129,14 +153,29 @@ class MockTransactionFactory : public DnsTransactionFactory {
|
| const DnsTransactionFactory::CallbackType& callback,
|
| const BoundNetLog&) OVERRIDE {
|
| return scoped_ptr<DnsTransaction>(
|
| - new MockTransaction(hostname, qtype, callback));
|
| + new MockTransaction(rules_, hostname, qtype, callback));
|
| + }
|
| +
|
| + private:
|
| + MockDnsClientRuleList rules_;
|
| +};
|
| +
|
| +class MockAddressSorter : public AddressSorter {
|
| + public:
|
| + virtual ~MockAddressSorter() {}
|
| + virtual void Sort(const AddressList& list,
|
| + const CallbackType& callback) const OVERRIDE {
|
| + // Do nothing.
|
| + callback.Run(true, list);
|
| }
|
| };
|
|
|
| // MockDnsClient provides MockTransactionFactory.
|
| class MockDnsClient : public DnsClient {
|
| public:
|
| - explicit MockDnsClient(const DnsConfig& config) : config_(config) {}
|
| + MockDnsClient(const DnsConfig& config,
|
| + const MockDnsClientRuleList& rules)
|
| + : config_(config), factory_(rules) {}
|
| virtual ~MockDnsClient() {}
|
|
|
| virtual void SetConfig(const DnsConfig& config) OVERRIDE {
|
| @@ -151,16 +190,22 @@ class MockDnsClient : public DnsClient {
|
| return config_.IsValid() ? &factory_ : NULL;
|
| }
|
|
|
| + virtual AddressSorter* GetAddressSorter() OVERRIDE {
|
| + return &address_sorter_;
|
| + }
|
| +
|
| private:
|
| DnsConfig config_;
|
| MockTransactionFactory factory_;
|
| + MockAddressSorter address_sorter_;
|
| };
|
|
|
| } // namespace
|
|
|
| // static
|
| -scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config) {
|
| - return scoped_ptr<DnsClient>(new MockDnsClient(config));
|
| +scoped_ptr<DnsClient> CreateMockDnsClient(const DnsConfig& config,
|
| + const MockDnsClientRuleList& rules) {
|
| + return scoped_ptr<DnsClient>(new MockDnsClient(config, rules));
|
| }
|
|
|
| MockDnsConfigService::~MockDnsConfigService() {
|
|
|