| OLD | NEW |
| 1 // Copyright 2016 The Chromium Authors. All rights reserved. | 1 // Copyright 2016 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "components/certificate_transparency/log_dns_client.h" | 5 #include "components/certificate_transparency/log_dns_client.h" |
| 6 | 6 |
| 7 #include <memory> | 7 #include <memory> |
| 8 #include <numeric> | 8 #include <numeric> |
| 9 #include <string> | 9 #include <string> |
| 10 #include <utility> | 10 #include <utility> |
| (...skipping 73 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 84 // Each node is 32 bytes, with each byte having a different value. | 84 // Each node is 32 bytes, with each byte having a different value. |
| 85 for (size_t j = 0; j < crypto::kSHA256Length; ++j) { | 85 for (size_t j = 0; j < crypto::kSHA256Length; ++j) { |
| 86 node[j] = static_cast<char>((-127 + i + j) % 128); | 86 node[j] = static_cast<char>((-127 + i + j) % 128); |
| 87 } | 87 } |
| 88 audit_proof[i].assign(std::move(node)); | 88 audit_proof[i].assign(std::move(node)); |
| 89 } | 89 } |
| 90 | 90 |
| 91 return audit_proof; | 91 return audit_proof; |
| 92 } | 92 } |
| 93 | 93 |
| 94 // MockAuditProofCallback can be used as an AuditProofCallback. | 94 // MockCallback can be used as a base::Callback. |
| 95 // It will record the arguments it is invoked with and provides a helpful | 95 // It will record the arguments it is invoked with, which can be examined by |
| 96 // method for pumping the message loop until it is invoked. | 96 // calling args() or arg<N>(). |
| 97 class MockAuditProofCallback { | 97 // It only expects to be called once, but can be reused by calling Reset(). |
| 98 // Example: |
| 99 // MockCallback<int> mock; |
| 100 // foo.RegisterCallback(mock.AsCallback()); |
| 101 // foo.DoSomething(); |
| 102 // mock.WaitUntilRun(TestTimeouts::action_max_timeout()); |
| 103 // ASSERT_TRUE(mock.called()); |
| 104 // ASSERT_EQ(123, mock.arg<0>()); |
| 105 template <typename... Args> |
| 106 class MockCallback { |
| 98 public: | 107 public: |
| 99 MockAuditProofCallback() : called_(false) {} | 108 MockCallback() : called_(false) {} |
| 100 | 109 |
| 110 // Returns true if the callback has been invoked. |
| 101 bool called() const { return called_; } | 111 bool called() const { return called_; } |
| 102 net::Error result() const { return result_; } | |
| 103 const net::ct::MerkleAuditProof* proof() const { return proof_.get(); } | |
| 104 | 112 |
| 105 // Get this callback as an AuditProofCallback. | 113 // The arguments that the callback was called with. |
| 106 LogDnsClient::AuditProofCallback AsCallback() { | 114 const std::tuple<Args...>& args() const { |
| 107 return base::Bind(&MockAuditProofCallback::Run, base::Unretained(this)); | 115 DCHECK(called_); |
| 116 return args_; |
| 117 } |
| 118 |
| 119 // Gets a particular argument that the callback was invoked with. |
| 120 // For example, to get the first argument: mock_callback.arg<0>(); |
| 121 template <size_t N> |
| 122 const typename std::tuple_element<N, std::tuple<Args...>>::type& arg() const { |
| 123 DCHECK(called_); |
| 124 return std::get<N>(args_); |
| 125 } |
| 126 |
| 127 // Convert to a base::Callback. |
| 128 // TODO(robpercival): Could this reasonably be an implicit conversion? |
| 129 base::Callback<void(Args...)> AsCallback() { |
| 130 return base::Bind(&MockCallback::Run, base::Unretained(this)); |
| 108 } | 131 } |
| 109 | 132 |
| 110 // Wait until either the callback is invoked or the message loop goes idle | 133 // Wait until either the callback is invoked or the message loop goes idle |
| 111 // (after a specified |timeout|). Returns immediately if the callback has | 134 // (after a specified |timeout|). Returns immediately if the callback has |
| 112 // already been invoked. | 135 // already been invoked. |
| 113 void WaitUntilRun(base::TimeDelta timeout) { | 136 void WaitUntilRun(base::TimeDelta timeout) { |
| 114 if (called_) { | 137 if (called_) { |
| 115 return; | 138 return; |
| 116 } | 139 } |
| 117 | 140 |
| 118 // Pump the message loop until the the callback is invoked, which quits the | 141 // Pump the message loop until the the callback is invoked, which quits the |
| 119 // RunLoop, or a timeout expires and the message loop goes idle. | 142 // RunLoop, or a timeout expires and the message loop goes idle. |
| 120 run_loop_.reset(new base::RunLoop()); | 143 run_loop_.reset(new base::RunLoop()); |
| 121 base::Closure quit_closure = run_loop_->QuitWhenIdleClosure(); | 144 base::Closure quit_closure = run_loop_->QuitWhenIdleClosure(); |
| 122 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(FROM_HERE, | 145 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(FROM_HERE, |
| 123 quit_closure, timeout); | 146 quit_closure, timeout); |
| 124 run_loop_->Run(); | 147 run_loop_->Run(); |
| 125 run_loop_.reset(); | 148 run_loop_.reset(); |
| 126 } | 149 } |
| 127 | 150 |
| 151 void Reset() { |
| 152 called_ = false; |
| 153 args_ = std::tuple<Args...>(); |
| 154 } |
| 155 |
| 128 private: | 156 private: |
| 129 void Run(net::Error result, | 157 void Run(Args... args) { |
| 130 std::unique_ptr<net::ct::MerkleAuditProof> proof) { | |
| 131 EXPECT_FALSE(called_); | 158 EXPECT_FALSE(called_); |
| 132 called_ = true; | 159 called_ = true; |
| 133 result_ = result; | 160 args_ = std::make_tuple(std::forward<Args>(args)...); |
| 134 proof_ = std::move(proof); | |
| 135 if (run_loop_) { | 161 if (run_loop_) { |
| 136 run_loop_->Quit(); | 162 run_loop_->Quit(); |
| 137 } | 163 } |
| 138 } | 164 } |
| 139 | 165 |
| 140 // True if the callback has been invoked. | 166 // True if the callback has been invoked. |
| 141 bool called_; | 167 bool called_; |
| 142 // The arguments that the callback was invoked with. | 168 // The arguments that the callback was invoked with. |
| 143 net::Error result_; | 169 std::tuple<Args...> args_; |
| 144 std::unique_ptr<net::ct::MerkleAuditProof> proof_; | |
| 145 // The RunLoop currently being used to pump the message loop, as a means to | 170 // The RunLoop currently being used to pump the message loop, as a means to |
| 146 // execute this callback. | 171 // execute this callback. |
| 147 std::unique_ptr<base::RunLoop> run_loop_; | 172 std::unique_ptr<base::RunLoop> run_loop_; |
| 148 }; | 173 }; |
| 149 | 174 |
| 175 class MockAuditProofCallback |
| 176 : public MockCallback<net::Error, |
| 177 std::unique_ptr<net::ct::MerkleAuditProof>> { |
| 178 public: |
| 179 net::Error result() const { return arg<0>(); } |
| 180 const net::ct::MerkleAuditProof* proof() const { return arg<1>().get(); } |
| 181 }; |
| 182 |
| 183 class MockClosure : public MockCallback<> {}; |
| 184 |
| 150 class LogDnsClientTest : public ::testing::TestWithParam<net::IoMode> { | 185 class LogDnsClientTest : public ::testing::TestWithParam<net::IoMode> { |
| 151 protected: | 186 protected: |
| 152 LogDnsClientTest() | 187 LogDnsClientTest() |
| 153 : network_change_notifier_(net::NetworkChangeNotifier::CreateMock()) { | 188 : network_change_notifier_(net::NetworkChangeNotifier::CreateMock()) { |
| 154 mock_dns_.SetSocketReadMode(GetParam()); | 189 mock_dns_.SetSocketReadMode(GetParam()); |
| 155 mock_dns_.InitializeDnsConfig(); | 190 mock_dns_.InitializeDnsConfig(); |
| 156 } | 191 } |
| 157 | 192 |
| 158 std::unique_ptr<LogDnsClient> CreateLogDnsClient( | 193 std::unique_ptr<LogDnsClient> CreateLogDnsClient( |
| 159 size_t max_concurrent_queries) { | 194 size_t max_concurrent_queries) { |
| (...skipping 715 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
| 875 callback3.WaitUntilRun(TestTimeouts::action_max_timeout()); | 910 callback3.WaitUntilRun(TestTimeouts::action_max_timeout()); |
| 876 ASSERT_TRUE(callback3.called()); | 911 ASSERT_TRUE(callback3.called()); |
| 877 EXPECT_THAT(callback3.result(), IsOk()); | 912 EXPECT_THAT(callback3.result(), IsOk()); |
| 878 ASSERT_THAT(callback3.proof(), NotNull()); | 913 ASSERT_THAT(callback3.proof(), NotNull()); |
| 879 EXPECT_THAT(callback3.proof()->leaf_index, Eq(666u)); | 914 EXPECT_THAT(callback3.proof()->leaf_index, Eq(666u)); |
| 880 // TODO(robpercival): Enable this once MerkleAuditProof has tree_size. | 915 // TODO(robpercival): Enable this once MerkleAuditProof has tree_size. |
| 881 // EXPECT_THAT(callback3.proof()->tree_size, Eq(999999)); | 916 // EXPECT_THAT(callback3.proof()->tree_size, Eq(999999)); |
| 882 EXPECT_THAT(callback3.proof()->nodes, Eq(audit_proof)); | 917 EXPECT_THAT(callback3.proof()->nodes, Eq(audit_proof)); |
| 883 } | 918 } |
| 884 | 919 |
| 920 TEST_P(LogDnsClientTest, NotifiesWhenNoLongerThrottled) { |
| 921 const std::vector<std::string> audit_proof = GetSampleAuditProof(20); |
| 922 |
| 923 mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[0], 123456); |
| 924 mock_dns_.ExpectAuditProofRequestAndResponse("0.123456.999999.tree.ct.test.", |
| 925 audit_proof.begin(), |
| 926 audit_proof.begin() + 7); |
| 927 mock_dns_.ExpectAuditProofRequestAndResponse("7.123456.999999.tree.ct.test.", |
| 928 audit_proof.begin() + 7, |
| 929 audit_proof.begin() + 14); |
| 930 mock_dns_.ExpectAuditProofRequestAndResponse("14.123456.999999.tree.ct.test.", |
| 931 audit_proof.begin() + 14, |
| 932 audit_proof.end()); |
| 933 |
| 934 const size_t kMaxConcurrentQueries = 1; |
| 935 std::unique_ptr<LogDnsClient> log_client = |
| 936 CreateLogDnsClient(kMaxConcurrentQueries); |
| 937 |
| 938 // Start a query. |
| 939 MockAuditProofCallback proof_callback1; |
| 940 ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[0], 999999, |
| 941 proof_callback1.AsCallback()), |
| 942 IsError(net::ERR_IO_PENDING)); |
| 943 |
| 944 MockClosure not_throttled_callback; |
| 945 log_client->NotifyWhenNotThrottled(not_throttled_callback.AsCallback()); |
| 946 |
| 947 proof_callback1.WaitUntilRun(TestTimeouts::action_max_timeout()); |
| 948 ASSERT_TRUE(proof_callback1.called()); |
| 949 ASSERT_TRUE(not_throttled_callback.called()); |
| 950 |
| 951 // Start another query to check |not_throttled_callback| doesn't fire again. |
| 952 not_throttled_callback.Reset(); |
| 953 |
| 954 mock_dns_.ExpectLeafIndexRequestAndResponse(kLeafIndexQnames[1], 666); |
| 955 mock_dns_.ExpectAuditProofRequestAndResponse("0.666.999999.tree.ct.test.", |
| 956 audit_proof.begin(), |
| 957 audit_proof.begin() + 7); |
| 958 mock_dns_.ExpectAuditProofRequestAndResponse("7.666.999999.tree.ct.test.", |
| 959 audit_proof.begin() + 7, |
| 960 audit_proof.begin() + 14); |
| 961 mock_dns_.ExpectAuditProofRequestAndResponse("14.666.999999.tree.ct.test.", |
| 962 audit_proof.begin() + 14, |
| 963 audit_proof.end()); |
| 964 |
| 965 MockAuditProofCallback proof_callback2; |
| 966 ASSERT_THAT(log_client->QueryAuditProof("ct.test", kLeafHashes[1], 999999, |
| 967 proof_callback2.AsCallback()), |
| 968 IsError(net::ERR_IO_PENDING)); |
| 969 |
| 970 // Give the query a chance to run. |
| 971 proof_callback2.WaitUntilRun(TestTimeouts::action_max_timeout()); |
| 972 |
| 973 ASSERT_TRUE(proof_callback2.called()); |
| 974 ASSERT_FALSE(not_throttled_callback.called()); |
| 975 } |
| 976 |
| 885 INSTANTIATE_TEST_CASE_P(ReadMode, | 977 INSTANTIATE_TEST_CASE_P(ReadMode, |
| 886 LogDnsClientTest, | 978 LogDnsClientTest, |
| 887 ::testing::Values(net::IoMode::ASYNC, | 979 ::testing::Values(net::IoMode::ASYNC, |
| 888 net::IoMode::SYNCHRONOUS)); | 980 net::IoMode::SYNCHRONOUS)); |
| 889 | 981 |
| 890 } // namespace | 982 } // namespace |
| 891 } // namespace certificate_transparency | 983 } // namespace certificate_transparency |
| OLD | NEW |