OLD | NEW |
1 // Copyright 2014 The Chromium Authors. All rights reserved. | 1 // Copyright 2014 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "base/at_exit.h" | 5 #include "base/at_exit.h" |
6 #include "base/macros.h" | 6 #include "base/macros.h" |
7 #include "base/memory/scoped_ptr.h" | 7 #include "base/memory/scoped_ptr.h" |
8 #include "mojo/public/cpp/bindings/callback.h" | 8 #include "mojo/public/cpp/bindings/callback.h" |
| 9 #include "mojo/services/public/cpp/network/udp_socket_wrapper.h" |
9 #include "mojo/services/public/interfaces/network/network_service.mojom.h" | 10 #include "mojo/services/public/interfaces/network/network_service.mojom.h" |
10 #include "mojo/services/public/interfaces/network/udp_socket.mojom.h" | 11 #include "mojo/services/public/interfaces/network/udp_socket.mojom.h" |
11 #include "mojo/shell/shell_test_helper.h" | 12 #include "mojo/shell/shell_test_helper.h" |
12 #include "net/base/net_errors.h" | 13 #include "net/base/net_errors.h" |
13 #include "testing/gtest/include/gtest/gtest.h" | 14 #include "testing/gtest/include/gtest/gtest.h" |
14 #include "url/gurl.h" | 15 #include "url/gurl.h" |
15 | 16 |
16 namespace mojo { | 17 namespace mojo { |
17 namespace service { | 18 namespace service { |
18 namespace { | 19 namespace { |
(...skipping 12 matching lines...) Expand all Loading... |
31 return addr.Pass(); | 32 return addr.Pass(); |
32 } | 33 } |
33 | 34 |
34 Array<uint8_t> CreateTestMessage(uint8_t initial, size_t size) { | 35 Array<uint8_t> CreateTestMessage(uint8_t initial, size_t size) { |
35 Array<uint8_t> array(size); | 36 Array<uint8_t> array(size); |
36 for (size_t i = 0; i < size; ++i) | 37 for (size_t i = 0; i < size; ++i) |
37 array[i] = static_cast<uint8_t>((i + initial) % 256); | 38 array[i] = static_cast<uint8_t>((i + initial) % 256); |
38 return array.Pass(); | 39 return array.Pass(); |
39 } | 40 } |
40 | 41 |
41 bool AreEqualArrays(const Array<uint8_t>& array_1, | |
42 const Array<uint8_t>& array_2) { | |
43 if (array_1.is_null() != array_2.is_null()) | |
44 return false; | |
45 else if (array_1.is_null()) | |
46 return true; | |
47 | |
48 if (array_1.size() != array_2.size()) | |
49 return false; | |
50 | |
51 for (size_t i = 0; i < array_1.size(); ++i) { | |
52 if (array_1[i] != array_2[i]) | |
53 return false; | |
54 } | |
55 | |
56 return true; | |
57 } | |
58 | |
59 template <typename CallbackType> | 42 template <typename CallbackType> |
60 class TestCallbackBase { | 43 class TestCallbackBase { |
61 public: | 44 public: |
62 TestCallbackBase() : state_(nullptr), run_loop_(nullptr), ran_(false) {} | 45 TestCallbackBase() : state_(nullptr), run_loop_(nullptr), ran_(false) {} |
63 | 46 |
64 ~TestCallbackBase() { | 47 ~TestCallbackBase() { |
65 state_->set_test_callback(nullptr); | 48 state_->set_test_callback(nullptr); |
66 } | 49 } |
67 | 50 |
68 CallbackType callback() const { return callback_; } | 51 CallbackType callback() const { return callback_; } |
69 | 52 |
70 void WaitForResult() { | 53 void WaitForResult() { |
71 if (ran_) | 54 if (ran_) |
72 return; | 55 return; |
73 | 56 |
74 base::RunLoop run_loop; | 57 base::RunLoop run_loop; |
75 run_loop_ = &run_loop; | 58 run_loop_ = &run_loop; |
76 run_loop.Run(); | 59 run_loop.Run(); |
77 run_loop_ = nullptr; | 60 run_loop_ = nullptr; |
78 } | 61 } |
79 | 62 |
80 protected: | 63 protected: |
81 struct StateBase : public CallbackType::Runnable { | 64 struct StateBase : public CallbackType::Runnable { |
82 StateBase() : test_callback_(nullptr) {} | 65 StateBase() : test_callback_(nullptr) {} |
83 virtual ~StateBase() {} | 66 ~StateBase() override {} |
84 | 67 |
85 void set_test_callback(TestCallbackBase* test_callback) { | 68 void set_test_callback(TestCallbackBase* test_callback) { |
86 test_callback_ = test_callback; | 69 test_callback_ = test_callback; |
87 } | 70 } |
88 | 71 |
89 protected: | 72 protected: |
90 void NotifyRun() const { | 73 void NotifyRun() const { |
91 if (test_callback_) { | 74 if (test_callback_) { |
92 test_callback_->ran_ = true; | 75 test_callback_->ran_ = true; |
93 if (test_callback_->run_loop_) | 76 if (test_callback_->run_loop_) |
(...skipping 30 matching lines...) Expand all Loading... |
124 public: | 107 public: |
125 TestCallback() { | 108 TestCallback() { |
126 Initialize(new State()); | 109 Initialize(new State()); |
127 } | 110 } |
128 ~TestCallback() {} | 111 ~TestCallback() {} |
129 | 112 |
130 const NetworkErrorPtr& result() const { return result_; } | 113 const NetworkErrorPtr& result() const { return result_; } |
131 | 114 |
132 private: | 115 private: |
133 struct State: public StateBase { | 116 struct State: public StateBase { |
134 virtual ~State() {} | 117 ~State() override {} |
135 | 118 |
136 virtual void Run(NetworkErrorPtr result) const override { | 119 void Run(NetworkErrorPtr result) const override { |
137 if (test_callback_) { | 120 if (test_callback_) { |
138 TestCallback* callback = static_cast<TestCallback*>(test_callback_); | 121 TestCallback* callback = static_cast<TestCallback*>(test_callback_); |
139 callback->result_ = result.Pass(); | 122 callback->result_ = result.Pass(); |
140 } | 123 } |
141 NotifyRun(); | 124 NotifyRun(); |
142 } | 125 } |
143 }; | 126 }; |
144 | 127 |
145 NetworkErrorPtr result_; | 128 NetworkErrorPtr result_; |
146 }; | 129 }; |
147 | 130 |
148 class TestCallbackWithAddress | 131 class TestCallbackWithAddress |
149 : public TestCallbackBase<Callback<void(NetworkErrorPtr, NetAddressPtr)>> { | 132 : public TestCallbackBase<Callback<void(NetworkErrorPtr, NetAddressPtr)>> { |
150 public: | 133 public: |
151 TestCallbackWithAddress() { | 134 TestCallbackWithAddress() { |
152 Initialize(new State()); | 135 Initialize(new State()); |
153 } | 136 } |
154 ~TestCallbackWithAddress() {} | 137 ~TestCallbackWithAddress() {} |
155 | 138 |
156 const NetworkErrorPtr& result() const { return result_; } | 139 const NetworkErrorPtr& result() const { return result_; } |
157 const NetAddressPtr& net_address() const { return net_address_; } | 140 const NetAddressPtr& net_address() const { return net_address_; } |
158 | 141 |
159 private: | 142 private: |
160 struct State : public StateBase { | 143 struct State : public StateBase { |
161 virtual ~State() {} | 144 ~State() override {} |
162 | 145 |
163 virtual void Run(NetworkErrorPtr result, | 146 void Run(NetworkErrorPtr result, NetAddressPtr net_address) const override { |
164 NetAddressPtr net_address) const override { | |
165 if (test_callback_) { | 147 if (test_callback_) { |
166 TestCallbackWithAddress* callback = | 148 TestCallbackWithAddress* callback = |
167 static_cast<TestCallbackWithAddress*>(test_callback_); | 149 static_cast<TestCallbackWithAddress*>(test_callback_); |
168 callback->result_ = result.Pass(); | 150 callback->result_ = result.Pass(); |
169 callback->net_address_ = net_address.Pass(); | 151 callback->net_address_ = net_address.Pass(); |
170 } | 152 } |
171 NotifyRun(); | 153 NotifyRun(); |
172 } | 154 } |
173 }; | 155 }; |
174 | 156 |
175 NetworkErrorPtr result_; | 157 NetworkErrorPtr result_; |
176 NetAddressPtr net_address_; | 158 NetAddressPtr net_address_; |
177 }; | 159 }; |
178 | 160 |
179 class TestCallbackWithUint32 | 161 class TestCallbackWithUint32 |
180 : public TestCallbackBase<Callback<void(uint32_t)>> { | 162 : public TestCallbackBase<Callback<void(uint32_t)>> { |
181 public: | 163 public: |
182 TestCallbackWithUint32() : result_(0) { | 164 TestCallbackWithUint32() : result_(0) { |
183 Initialize(new State()); | 165 Initialize(new State()); |
184 } | 166 } |
185 ~TestCallbackWithUint32() {} | 167 ~TestCallbackWithUint32() {} |
186 | 168 |
187 uint32_t result() const { return result_; } | 169 uint32_t result() const { return result_; } |
188 | 170 |
189 private: | 171 private: |
190 struct State : public StateBase { | 172 struct State : public StateBase { |
191 virtual ~State() {} | 173 ~State() override {} |
192 | 174 |
193 virtual void Run(uint32_t result) const override { | 175 void Run(uint32_t result) const override { |
194 if (test_callback_) { | 176 if (test_callback_) { |
195 TestCallbackWithUint32* callback = | 177 TestCallbackWithUint32* callback = |
196 static_cast<TestCallbackWithUint32*>(test_callback_); | 178 static_cast<TestCallbackWithUint32*>(test_callback_); |
197 callback->result_ = result; | 179 callback->result_ = result; |
198 } | 180 } |
199 NotifyRun(); | 181 NotifyRun(); |
200 } | 182 } |
201 }; | 183 }; |
202 | 184 |
203 uint32_t result_; | 185 uint32_t result_; |
204 }; | 186 }; |
205 | 187 |
| 188 class TestReceiveCallback |
| 189 : public TestCallbackBase< |
| 190 Callback<void(NetworkErrorPtr, NetAddressPtr, Array<uint8_t>)>> { |
| 191 public: |
| 192 TestReceiveCallback() { |
| 193 Initialize(new State()); |
| 194 } |
| 195 ~TestReceiveCallback() {} |
| 196 |
| 197 const NetworkErrorPtr& result() const { return result_; } |
| 198 const NetAddressPtr& src_addr() const { return src_addr_; } |
| 199 const Array<uint8_t>& data() const { return data_; } |
| 200 |
| 201 private: |
| 202 struct State : public StateBase { |
| 203 ~State() override {} |
| 204 |
| 205 void Run(NetworkErrorPtr result, |
| 206 NetAddressPtr src_addr, |
| 207 Array<uint8_t> data) const override { |
| 208 if (test_callback_) { |
| 209 TestReceiveCallback* callback = |
| 210 static_cast<TestReceiveCallback*>(test_callback_); |
| 211 callback->result_ = result.Pass(); |
| 212 callback->src_addr_ = src_addr.Pass(); |
| 213 callback->data_ = data.Pass(); |
| 214 } |
| 215 NotifyRun(); |
| 216 } |
| 217 }; |
| 218 |
| 219 NetworkErrorPtr result_; |
| 220 NetAddressPtr src_addr_; |
| 221 Array<uint8_t> data_; |
| 222 }; |
| 223 |
206 class UDPSocketTest : public testing::Test { | 224 class UDPSocketTest : public testing::Test { |
207 public: | 225 public: |
208 UDPSocketTest() {} | 226 UDPSocketTest() {} |
209 virtual ~UDPSocketTest() {} | 227 ~UDPSocketTest() override {} |
210 | 228 |
211 virtual void SetUp() override { | 229 void SetUp() override { |
212 test_helper_.Init(); | 230 test_helper_.Init(); |
213 | 231 |
214 test_helper_.application_manager()->ConnectToService( | 232 test_helper_.application_manager()->ConnectToService( |
215 GURL("mojo:mojo_network_service"), &network_service_); | 233 GURL("mojo:mojo_network_service"), &network_service_); |
216 | 234 |
217 network_service_->CreateUDPSocket(GetProxy(&udp_socket_)); | 235 network_service_->CreateUDPSocket(GetProxy(&udp_socket_)); |
218 udp_socket_.set_client(&udp_socket_client_); | 236 udp_socket_.set_client(&udp_socket_client_); |
219 } | 237 } |
220 | 238 |
221 protected: | 239 protected: |
222 struct ReceiveResult { | 240 struct ReceiveResult { |
223 NetworkErrorPtr result; | 241 NetworkErrorPtr result; |
224 NetAddressPtr addr; | 242 NetAddressPtr addr; |
225 Array<uint8_t> data; | 243 Array<uint8_t> data; |
226 }; | 244 }; |
227 | 245 |
228 class UDPSocketClientImpl : public UDPSocketClient { | 246 class UDPSocketClientImpl : public UDPSocketClient { |
229 public: | 247 public: |
230 | 248 |
231 UDPSocketClientImpl() : run_loop_(nullptr), expected_receive_count_(0) {} | 249 UDPSocketClientImpl() : run_loop_(nullptr), expected_receive_count_(0) {} |
232 | 250 |
233 virtual ~UDPSocketClientImpl() { | 251 ~UDPSocketClientImpl() override { |
234 while (!results_.empty()) { | 252 while (!results_.empty()) { |
235 delete results_.front(); | 253 delete results_.front(); |
236 results_.pop(); | 254 results_.pop(); |
237 } | 255 } |
238 } | 256 } |
239 | 257 |
240 virtual void OnReceived(NetworkErrorPtr result, | 258 void OnReceived(NetworkErrorPtr result, |
241 NetAddressPtr src_addr, | 259 NetAddressPtr src_addr, |
242 Array<uint8_t> data) override { | 260 Array<uint8_t> data) override { |
243 ReceiveResult* entry = new ReceiveResult(); | 261 ReceiveResult* entry = new ReceiveResult(); |
244 entry->result = result.Pass(); | 262 entry->result = result.Pass(); |
245 entry->addr = src_addr.Pass(); | 263 entry->addr = src_addr.Pass(); |
246 entry->data = data.Pass(); | 264 entry->data = data.Pass(); |
247 | 265 |
248 results_.push(entry); | 266 results_.push(entry); |
249 | 267 |
250 if (results_.size() == expected_receive_count_ && run_loop_) { | 268 if (results_.size() == expected_receive_count_ && run_loop_) { |
251 expected_receive_count_ = 0; | 269 expected_receive_count_ = 0; |
252 run_loop_->Quit(); | 270 run_loop_->Quit(); |
(...skipping 96 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
349 | 367 |
350 UDPSocketPtr client_socket; | 368 UDPSocketPtr client_socket; |
351 network_service_->CreateUDPSocket(GetProxy(&client_socket)); | 369 network_service_->CreateUDPSocket(GetProxy(&client_socket)); |
352 | 370 |
353 TestCallbackWithAddress callback2; | 371 TestCallbackWithAddress callback2; |
354 client_socket->Bind(GetLocalHostWithAnyPort(), callback2.callback()); | 372 client_socket->Bind(GetLocalHostWithAnyPort(), callback2.callback()); |
355 callback2.WaitForResult(); | 373 callback2.WaitForResult(); |
356 ASSERT_EQ(net::OK, callback2.result()->code); | 374 ASSERT_EQ(net::OK, callback2.result()->code); |
357 ASSERT_NE(0u, callback2.net_address()->ipv4->port); | 375 ASSERT_NE(0u, callback2.net_address()->ipv4->port); |
358 | 376 |
| 377 NetAddressPtr client_addr = callback2.net_address().Clone(); |
| 378 |
359 const size_t kDatagramCount = 6; | 379 const size_t kDatagramCount = 6; |
360 const size_t kDatagramSize = 255; | 380 const size_t kDatagramSize = 255; |
361 udp_socket_->ReceiveMore(kDatagramCount); | 381 udp_socket_->ReceiveMore(kDatagramCount); |
362 | 382 |
363 for (size_t i = 0; i < kDatagramCount; ++i) { | 383 for (size_t i = 0; i < kDatagramCount; ++i) { |
364 TestCallback callback; | 384 TestCallback callback; |
365 client_socket->SendTo( | 385 client_socket->SendTo( |
366 server_addr.Clone(), | 386 server_addr.Clone(), |
367 CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize), | 387 CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize), |
368 callback.callback()); | 388 callback.callback()); |
369 callback.WaitForResult(); | 389 callback.WaitForResult(); |
370 EXPECT_EQ(255, callback.result()->code); | 390 EXPECT_EQ(255, callback.result()->code); |
371 } | 391 } |
372 | 392 |
373 WaitForReceiveResults(kDatagramCount); | 393 WaitForReceiveResults(kDatagramCount); |
374 for (size_t i = 0; i < kDatagramCount; ++i) { | 394 for (size_t i = 0; i < kDatagramCount; ++i) { |
375 scoped_ptr<ReceiveResult> result(GetReceiveResults()->front()); | 395 scoped_ptr<ReceiveResult> result(GetReceiveResults()->front()); |
376 GetReceiveResults()->pop(); | 396 GetReceiveResults()->pop(); |
377 | 397 |
378 EXPECT_EQ(static_cast<int>(kDatagramSize), result->result->code); | 398 EXPECT_EQ(static_cast<int>(kDatagramSize), result->result->code); |
379 EXPECT_TRUE(AreEqualArrays( | 399 EXPECT_TRUE(result->addr.Equals(client_addr)); |
380 CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize), | 400 EXPECT_TRUE(result->data.Equals( |
381 result->data)); | 401 CreateTestMessage(static_cast<uint8_t>(i), kDatagramSize))); |
382 } | 402 } |
383 } | 403 } |
384 | 404 |
| 405 TEST_F(UDPSocketTest, TestUDPSocketWrapper) { |
| 406 UDPSocketWrapper udp_socket(udp_socket_.Pass(), 4, 4); |
| 407 |
| 408 TestCallbackWithAddress callback1; |
| 409 udp_socket.Bind(GetLocalHostWithAnyPort(), callback1.callback()); |
| 410 callback1.WaitForResult(); |
| 411 ASSERT_EQ(net::OK, callback1.result()->code); |
| 412 ASSERT_NE(0u, callback1.net_address()->ipv4->port); |
| 413 |
| 414 NetAddressPtr server_addr = callback1.net_address().Clone(); |
| 415 |
| 416 UDPSocketPtr raw_client_socket; |
| 417 network_service_->CreateUDPSocket(GetProxy(&raw_client_socket)); |
| 418 UDPSocketWrapper client_socket(raw_client_socket.Pass(), 4, 4); |
| 419 |
| 420 TestCallbackWithAddress callback2; |
| 421 client_socket.Bind(GetLocalHostWithAnyPort(), callback2.callback()); |
| 422 callback2.WaitForResult(); |
| 423 ASSERT_EQ(net::OK, callback2.result()->code); |
| 424 ASSERT_NE(0u, callback2.net_address()->ipv4->port); |
| 425 |
| 426 NetAddressPtr client_addr = callback2.net_address().Clone(); |
| 427 |
| 428 const size_t kDatagramCount = 16; |
| 429 const size_t kDatagramSize = 255; |
| 430 |
| 431 for (size_t i = 1; i < kDatagramCount; ++i) { |
| 432 scoped_ptr<TestCallback[]> send_callbacks(new TestCallback[i]); |
| 433 scoped_ptr<TestReceiveCallback[]> receive_callbacks( |
| 434 new TestReceiveCallback[i]); |
| 435 |
| 436 for (size_t j = 0; j < i; ++j) { |
| 437 client_socket.SendTo( |
| 438 server_addr.Clone(), |
| 439 CreateTestMessage(static_cast<uint8_t>(j), kDatagramSize), |
| 440 send_callbacks[j].callback()); |
| 441 |
| 442 udp_socket.ReceiveFrom(receive_callbacks[j].callback()); |
| 443 } |
| 444 |
| 445 receive_callbacks[i - 1].WaitForResult(); |
| 446 |
| 447 for (size_t j = 0; j < i; ++j) { |
| 448 EXPECT_EQ(static_cast<int>(kDatagramSize), |
| 449 receive_callbacks[j].result()->code); |
| 450 EXPECT_TRUE(receive_callbacks[j].src_addr().Equals(client_addr)); |
| 451 EXPECT_TRUE(receive_callbacks[j].data().Equals( |
| 452 CreateTestMessage(static_cast<uint8_t>(j), kDatagramSize))); |
| 453 } |
| 454 } |
| 455 } |
| 456 |
385 } // namespace service | 457 } // namespace service |
386 } // namespace mojo | 458 } // namespace mojo |
OLD | NEW |