OLD | NEW |
| (Empty) |
1 // Copyright 2017 The Chromium Authors. All rights reserved. | |
2 // Use of this source code is governed by a BSD-style license that can be | |
3 // found in the LICENSE file. | |
4 | |
5 #include "components/translate/core/browser/ranker_model_loader.h" | |
6 | |
7 #include <deque> | |
8 #include <initializer_list> | |
9 #include <memory> | |
10 #include <vector> | |
11 | |
12 #include "base/files/file_util.h" | |
13 #include "base/files/scoped_temp_dir.h" | |
14 #include "base/memory/ptr_util.h" | |
15 #include "base/memory/ref_counted.h" | |
16 #include "base/message_loop/message_loop.h" | |
17 #include "base/run_loop.h" | |
18 #include "base/strings/stringprintf.h" | |
19 #include "base/task_scheduler/post_task.h" | |
20 #include "base/task_scheduler/task_scheduler.h" | |
21 #include "base/test/scoped_feature_list.h" | |
22 #include "base/test/scoped_task_scheduler.h" | |
23 #include "base/test/test_simple_task_runner.h" | |
24 #include "base/threading/thread_task_runner_handle.h" | |
25 #include "components/metrics/proto/translate_event.pb.h" | |
26 #include "components/prefs/scoped_user_pref_update.h" | |
27 #include "components/sync_preferences/testing_pref_service_syncable.h" | |
28 #include "components/translate/core/browser/proto/ranker_model.pb.h" | |
29 #include "components/translate/core/browser/proto/translate_ranker_model.pb.h" | |
30 #include "components/translate/core/browser/ranker_model.h" | |
31 #include "components/translate/core/browser/translate_download_manager.h" | |
32 #include "components/translate/core/browser/translate_prefs.h" | |
33 #include "net/url_request/test_url_fetcher_factory.h" | |
34 #include "net/url_request/url_request_test_util.h" | |
35 #include "testing/gtest/include/gtest/gtest.h" | |
36 | |
37 namespace { | |
38 | |
39 using base::TaskScheduler; | |
40 using chrome_intelligence::RankerModel; | |
41 using translate::RankerModelLoader; | |
42 using translate::RankerModelStatus; | |
43 using translate::TranslateDownloadManager; | |
44 | |
45 const char kInvalidModelData[] = "not a valid model"; | |
46 const int kInvalidModelSize = sizeof(kInvalidModelData) - 1; | |
47 | |
48 class RankerModelLoaderTest : public ::testing::Test { | |
49 protected: | |
50 RankerModelLoaderTest(); | |
51 | |
52 void SetUp() override; | |
53 | |
54 void TearDown() override; | |
55 | |
56 // Returns a copy of |model|. | |
57 static std::unique_ptr<RankerModel> Clone(const RankerModel& model); | |
58 | |
59 // Returns true if |m1| and |m2| are identical. | |
60 static bool IsEqual(const RankerModel& m1, const RankerModel& m2); | |
61 | |
62 // Returns true if |m1| and |m2| are identical modulo metadata. | |
63 static bool IsEquivalent(const RankerModel& m1, const RankerModel& m2); | |
64 | |
65 // Helper method to drive the loader for |model_path| and |model_url|. | |
66 bool DoLoaderTest(const base::FilePath& model_path, const GURL& model_url); | |
67 | |
68 // Initialize the "remote" model data used for testing. | |
69 void InitRemoteModels(); | |
70 | |
71 // Initialize the "local" model data used for testing. | |
72 void InitLocalModels(); | |
73 | |
74 // Helper method used by InitRemoteModels() and InitLocalModels(). | |
75 void InitModel(const GURL& model_url, | |
76 const base::Time& last_modified, | |
77 const base::TimeDelta& cache_duration, | |
78 RankerModel* model); | |
79 | |
80 // Save |model| to |model_path|. Used by InitRemoteModels() and | |
81 // InitLocalModels() | |
82 void SaveModel(const RankerModel& model, const base::FilePath& model_path); | |
83 | |
84 // Implements RankerModelLoader's ValidateModelCallback interface. | |
85 RankerModelStatus ValidateModel(const RankerModel& model); | |
86 | |
87 // Implements RankerModelLoader's OnModelAvailableCallback interface. | |
88 void OnModelAvailable(std::unique_ptr<RankerModel> model); | |
89 | |
90 // Sets up the task scheduling/task-runner environment for each test. | |
91 base::test::ScopedTaskScheduler scoped_task_scheduler_; | |
92 | |
93 // Override the default URL fetcher to return custom responses for tests. | |
94 net::FakeURLFetcherFactory url_fetcher_factory_; | |
95 | |
96 // Temporary directory for model files. | |
97 base::ScopedTempDir scoped_temp_dir_; | |
98 | |
99 // Cache and reset the application locale for each test. | |
100 std::string locale_; | |
101 | |
102 // Used to initialize the translate download manager. | |
103 scoped_refptr<net::TestURLRequestContextGetter> request_context_; | |
104 | |
105 // A queue of responses to return from Validate(). If empty, validate will | |
106 // return 'OK'. | |
107 std::deque<RankerModelStatus> validate_model_response_; | |
108 | |
109 // A cached to remember the model validation calls. | |
110 std::vector<std::unique_ptr<RankerModel>> validated_models_; | |
111 | |
112 // A cache to remember the OnModelAvailable calls. | |
113 std::vector<std::unique_ptr<RankerModel>> available_models_; | |
114 | |
115 // Cached model file paths. | |
116 base::FilePath local_model_path_; | |
117 base::FilePath expired_model_path_; | |
118 base::FilePath invalid_model_path_; | |
119 | |
120 // Model URLS. | |
121 GURL remote_model_url_; | |
122 GURL invalid_model_url_; | |
123 GURL failed_model_url_; | |
124 | |
125 // Model Data. | |
126 RankerModel remote_model_; | |
127 RankerModel local_model_; | |
128 RankerModel expired_model_; | |
129 | |
130 private: | |
131 DISALLOW_COPY_AND_ASSIGN(RankerModelLoaderTest); | |
132 }; | |
133 | |
134 RankerModelLoaderTest::RankerModelLoaderTest() | |
135 : url_fetcher_factory_(nullptr) {} | |
136 | |
137 void RankerModelLoaderTest::SetUp() { | |
138 // Setup the translate download manager. | |
139 locale_ = TranslateDownloadManager::GetInstance()->application_locale(); | |
140 request_context_ = | |
141 new net::TestURLRequestContextGetter(base::ThreadTaskRunnerHandle::Get()); | |
142 TranslateDownloadManager::GetInstance()->set_application_locale("fr-CA"); | |
143 TranslateDownloadManager::GetInstance()->set_request_context( | |
144 request_context_.get()); | |
145 | |
146 ASSERT_TRUE(scoped_temp_dir_.CreateUniqueTempDir()); | |
147 const auto& temp_dir_path = scoped_temp_dir_.GetPath(); | |
148 | |
149 // Setup the model file paths. | |
150 local_model_path_ = temp_dir_path.AppendASCII("local_model.bin"); | |
151 expired_model_path_ = temp_dir_path.AppendASCII("expired_model.bin"); | |
152 invalid_model_path_ = temp_dir_path.AppendASCII("invalid_model.bin"); | |
153 | |
154 // Setup the model URLs. | |
155 remote_model_url_ = GURL("https://some.url.net/good.model.bin"); | |
156 invalid_model_url_ = GURL("https://some.url.net/bad.model.bin"); | |
157 failed_model_url_ = GURL("https://some.url.net/fail"); | |
158 | |
159 // Initialize the model data. | |
160 ASSERT_NO_FATAL_FAILURE(InitRemoteModels()); | |
161 ASSERT_NO_FATAL_FAILURE(InitLocalModels()); | |
162 } | |
163 | |
164 void RankerModelLoaderTest::TearDown() { | |
165 base::RunLoop().RunUntilIdle(); | |
166 TranslateDownloadManager::GetInstance()->set_application_locale(locale_); | |
167 TranslateDownloadManager::GetInstance()->set_request_context(nullptr); | |
168 } | |
169 | |
170 // static | |
171 std::unique_ptr<RankerModel> RankerModelLoaderTest::Clone( | |
172 const RankerModel& model) { | |
173 auto copy = base::MakeUnique<RankerModel>(); | |
174 *copy->mutable_proto() = model.proto(); | |
175 return copy; | |
176 } | |
177 | |
178 // static | |
179 bool RankerModelLoaderTest::IsEqual(const RankerModel& m1, | |
180 const RankerModel& m2) { | |
181 return m1.SerializeAsString() == m2.SerializeAsString(); | |
182 } | |
183 | |
184 // static | |
185 bool RankerModelLoaderTest::IsEquivalent(const RankerModel& m1, | |
186 const RankerModel& m2) { | |
187 auto copy_m1 = Clone(m1); | |
188 copy_m1->mutable_proto()->mutable_metadata()->Clear(); | |
189 | |
190 auto copy_m2 = Clone(m2); | |
191 copy_m2->mutable_proto()->mutable_metadata()->Clear(); | |
192 | |
193 return IsEqual(*copy_m1, *copy_m2); | |
194 } | |
195 | |
196 bool RankerModelLoaderTest::DoLoaderTest(const base::FilePath& model_path, | |
197 const GURL& model_url) { | |
198 auto loader = base::MakeUnique<RankerModelLoader>( | |
199 base::Bind(&RankerModelLoaderTest::ValidateModel, base::Unretained(this)), | |
200 base::Bind(&RankerModelLoaderTest::OnModelAvailable, | |
201 base::Unretained(this)), | |
202 model_path, model_url, "RankerModelLoaderTest"); | |
203 loader->NotifyOfRankerActivity(); | |
204 base::RunLoop().RunUntilIdle(); | |
205 | |
206 return true; | |
207 } | |
208 | |
209 void RankerModelLoaderTest::InitRemoteModels() { | |
210 InitModel(remote_model_url_, base::Time(), base::TimeDelta(), &remote_model_); | |
211 url_fetcher_factory_.SetFakeResponse( | |
212 remote_model_url_, remote_model_.SerializeAsString(), net::HTTP_OK, | |
213 net::URLRequestStatus::SUCCESS); | |
214 url_fetcher_factory_.SetFakeResponse(invalid_model_url_, kInvalidModelData, | |
215 net::HTTP_OK, | |
216 net::URLRequestStatus::SUCCESS); | |
217 url_fetcher_factory_.SetFakeResponse(failed_model_url_, "", | |
218 net::HTTP_INTERNAL_SERVER_ERROR, | |
219 net::URLRequestStatus::FAILED); | |
220 } | |
221 | |
222 void RankerModelLoaderTest::InitLocalModels() { | |
223 InitModel(remote_model_url_, base::Time::Now(), base::TimeDelta::FromDays(30), | |
224 &local_model_); | |
225 InitModel(remote_model_url_, | |
226 base::Time::Now() - base::TimeDelta::FromDays(60), | |
227 base::TimeDelta::FromDays(30), &expired_model_); | |
228 SaveModel(local_model_, local_model_path_); | |
229 SaveModel(expired_model_, expired_model_path_); | |
230 ASSERT_EQ(base::WriteFile(invalid_model_path_, kInvalidModelData, | |
231 kInvalidModelSize), | |
232 kInvalidModelSize); | |
233 } | |
234 | |
235 void RankerModelLoaderTest::InitModel(const GURL& model_url, | |
236 const base::Time& last_modified, | |
237 const base::TimeDelta& cache_duration, | |
238 RankerModel* model) { | |
239 ASSERT_TRUE(model != nullptr); | |
240 model->mutable_proto()->Clear(); | |
241 | |
242 auto* metadata = model->mutable_proto()->mutable_metadata(); | |
243 if (!model_url.is_empty()) | |
244 metadata->set_source(model_url.spec()); | |
245 if (!last_modified.is_null()) { | |
246 auto last_modified_sec = (last_modified - base::Time()).InSeconds(); | |
247 metadata->set_last_modified_sec(last_modified_sec); | |
248 } | |
249 if (!cache_duration.is_zero()) | |
250 metadata->set_cache_duration_sec(cache_duration.InSeconds()); | |
251 | |
252 auto* translate = model->mutable_proto()->mutable_translate(); | |
253 translate->set_version(1); | |
254 | |
255 auto* logit = translate->mutable_logistic_regression_model(); | |
256 logit->set_bias(0.1f); | |
257 logit->set_accept_ratio_weight(0.2f); | |
258 logit->set_decline_ratio_weight(0.3f); | |
259 logit->set_ignore_ratio_weight(0.4f); | |
260 } | |
261 | |
262 void RankerModelLoaderTest::SaveModel(const RankerModel& model, | |
263 const base::FilePath& model_path) { | |
264 std::string model_str = model.SerializeAsString(); | |
265 ASSERT_EQ(base::WriteFile(model_path, model_str.data(), model_str.size()), | |
266 static_cast<int>(model_str.size())); | |
267 } | |
268 | |
269 RankerModelStatus RankerModelLoaderTest::ValidateModel( | |
270 const RankerModel& model) { | |
271 validated_models_.push_back(Clone(model)); | |
272 RankerModelStatus response = RankerModelStatus::OK; | |
273 if (!validate_model_response_.empty()) { | |
274 response = validate_model_response_.front(); | |
275 validate_model_response_.pop_front(); | |
276 } | |
277 return response; | |
278 } | |
279 | |
280 void RankerModelLoaderTest::OnModelAvailable( | |
281 std::unique_ptr<RankerModel> model) { | |
282 available_models_.push_back(std::move(model)); | |
283 } | |
284 | |
285 } // namespace | |
286 | |
287 TEST_F(RankerModelLoaderTest, NoLocalOrRemoteModel) { | |
288 ASSERT_TRUE(DoLoaderTest(base::FilePath(), GURL())); | |
289 | |
290 EXPECT_EQ(0U, validated_models_.size()); | |
291 EXPECT_EQ(0U, available_models_.size()); | |
292 } | |
293 | |
294 TEST_F(RankerModelLoaderTest, BadLocalAndRemoteModel) { | |
295 ASSERT_TRUE(DoLoaderTest(invalid_model_path_, invalid_model_url_)); | |
296 | |
297 EXPECT_EQ(0U, validated_models_.size()); | |
298 EXPECT_EQ(0U, available_models_.size()); | |
299 } | |
300 | |
301 TEST_F(RankerModelLoaderTest, LoadFromFileOnly) { | |
302 EXPECT_TRUE(DoLoaderTest(local_model_path_, GURL())); | |
303 | |
304 ASSERT_EQ(1U, validated_models_.size()); | |
305 ASSERT_EQ(1U, available_models_.size()); | |
306 EXPECT_TRUE(IsEqual(*validated_models_[0], local_model_)); | |
307 EXPECT_TRUE(IsEqual(*available_models_[0], local_model_)); | |
308 } | |
309 | |
310 TEST_F(RankerModelLoaderTest, LoadFromFileSkipsDownload) { | |
311 ASSERT_TRUE(DoLoaderTest(local_model_path_, remote_model_url_)); | |
312 | |
313 ASSERT_EQ(1U, validated_models_.size()); | |
314 ASSERT_EQ(1U, available_models_.size()); | |
315 EXPECT_TRUE(IsEqual(*validated_models_[0], local_model_)); | |
316 EXPECT_TRUE(IsEqual(*available_models_[0], local_model_)); | |
317 } | |
318 | |
319 TEST_F(RankerModelLoaderTest, LoadFromFileAndBadUrl) { | |
320 ASSERT_TRUE(DoLoaderTest(local_model_path_, invalid_model_url_)); | |
321 ASSERT_EQ(1U, validated_models_.size()); | |
322 ASSERT_EQ(1U, available_models_.size()); | |
323 EXPECT_TRUE(IsEqual(*validated_models_[0], local_model_)); | |
324 EXPECT_TRUE(IsEqual(*available_models_[0], local_model_)); | |
325 } | |
326 | |
327 TEST_F(RankerModelLoaderTest, LoadFromURLOnly) { | |
328 ASSERT_TRUE(DoLoaderTest(base::FilePath(), remote_model_url_)); | |
329 ASSERT_EQ(1U, validated_models_.size()); | |
330 ASSERT_EQ(1U, available_models_.size()); | |
331 EXPECT_TRUE(IsEquivalent(*validated_models_[0], remote_model_)); | |
332 EXPECT_TRUE(IsEquivalent(*available_models_[0], remote_model_)); | |
333 } | |
334 | |
335 TEST_F(RankerModelLoaderTest, LoadFromExpiredFileTriggersDownload) { | |
336 ASSERT_TRUE(DoLoaderTest(expired_model_path_, remote_model_url_)); | |
337 ASSERT_EQ(2U, validated_models_.size()); | |
338 ASSERT_EQ(2U, available_models_.size()); | |
339 EXPECT_TRUE(IsEquivalent(*validated_models_[0], local_model_)); | |
340 EXPECT_TRUE(IsEquivalent(*available_models_[0], local_model_)); | |
341 EXPECT_TRUE(IsEquivalent(*validated_models_[1], remote_model_)); | |
342 EXPECT_TRUE(IsEquivalent(*available_models_[1], remote_model_)); | |
343 } | |
344 | |
345 TEST_F(RankerModelLoaderTest, LoadFromBadFileTriggersDownload) { | |
346 ASSERT_TRUE(DoLoaderTest(invalid_model_path_, remote_model_url_)); | |
347 ASSERT_EQ(1U, validated_models_.size()); | |
348 ASSERT_EQ(1U, available_models_.size()); | |
349 EXPECT_TRUE(IsEquivalent(*validated_models_[0], remote_model_)); | |
350 EXPECT_TRUE(IsEquivalent(*available_models_[0], remote_model_)); | |
351 } | |
352 | |
353 TEST_F(RankerModelLoaderTest, IncompatibleCachedFileTriggersDownload) { | |
354 validate_model_response_.push_back(RankerModelStatus::INCOMPATIBLE); | |
355 | |
356 ASSERT_TRUE(DoLoaderTest(local_model_path_, remote_model_url_)); | |
357 ASSERT_EQ(2U, validated_models_.size()); | |
358 ASSERT_EQ(1U, available_models_.size()); | |
359 EXPECT_TRUE(IsEquivalent(*validated_models_[0], local_model_)); | |
360 EXPECT_TRUE(IsEquivalent(*validated_models_[1], remote_model_)); | |
361 EXPECT_TRUE(IsEquivalent(*available_models_[0], remote_model_)); | |
362 } | |
363 | |
364 TEST_F(RankerModelLoaderTest, IncompatibleDownloadedFileKeepsExpired) { | |
365 validate_model_response_.push_back(RankerModelStatus::OK); | |
366 validate_model_response_.push_back(RankerModelStatus::INCOMPATIBLE); | |
367 | |
368 ASSERT_TRUE(DoLoaderTest(expired_model_path_, remote_model_url_)); | |
369 ASSERT_EQ(2U, validated_models_.size()); | |
370 ASSERT_EQ(1U, available_models_.size()); | |
371 EXPECT_TRUE(IsEquivalent(*validated_models_[0], local_model_)); | |
372 EXPECT_TRUE(IsEquivalent(*validated_models_[1], remote_model_)); | |
373 EXPECT_TRUE(IsEquivalent(*available_models_[0], local_model_)); | |
374 } | |
OLD | NEW |