OLD | NEW |
1 // Copyright 2015 The Chromium Authors. All rights reserved. | 1 // Copyright 2015 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "components/dom_distiller/core/distillable_page_detector.h" | 5 #include "components/dom_distiller/core/distillable_page_detector.h" |
6 | 6 |
7 #include "base/logging.h" | 7 #include "base/logging.h" |
| 8 #include "grit/components_resources.h" |
| 9 #include "ui/base/resource/resource_bundle.h" |
8 | 10 |
9 namespace dom_distiller { | 11 namespace dom_distiller { |
10 | 12 |
| 13 const DistillablePageDetector* DistillablePageDetector::GetDefault() { |
| 14 static DistillablePageDetector* detector = nullptr; |
| 15 if (!detector) { |
| 16 std::string serialized_proto = |
| 17 ResourceBundle::GetSharedInstance() |
| 18 .GetRawDataResource(IDR_DISTILLABLE_PAGE_SERIALIZED_MODEL) |
| 19 .as_string(); |
| 20 scoped_ptr<AdaBoostProto> proto(new AdaBoostProto); |
| 21 CHECK(proto->ParseFromString(serialized_proto)); |
| 22 detector = new DistillablePageDetector(proto.Pass()); |
| 23 } |
| 24 return detector; |
| 25 } |
| 26 |
11 DistillablePageDetector::DistillablePageDetector( | 27 DistillablePageDetector::DistillablePageDetector( |
12 scoped_ptr<AdaBoostProto> proto) | 28 scoped_ptr<AdaBoostProto> proto) |
13 : proto_(proto.Pass()), threshold_(0.0) { | 29 : proto_(proto.Pass()), threshold_(0.0) { |
14 CHECK(proto_->num_stumps() == proto_->stump_size()); | 30 CHECK(proto_->num_stumps() == proto_->stump_size()); |
15 for (int i = 0; i < proto_->num_stumps(); ++i) { | 31 for (int i = 0; i < proto_->num_stumps(); ++i) { |
16 const StumpProto& stump = proto_->stump(i); | 32 const StumpProto& stump = proto_->stump(i); |
17 CHECK(stump.feature_number() >= 0); | 33 CHECK(stump.feature_number() >= 0); |
18 CHECK(stump.feature_number() < proto_->num_features()); | 34 CHECK(stump.feature_number() < proto_->num_features()); |
19 threshold_ += stump.weight() / 2.0; | 35 threshold_ += stump.weight() / 2.0; |
20 } | 36 } |
21 } | 37 } |
22 | 38 |
23 DistillablePageDetector::~DistillablePageDetector() { | 39 DistillablePageDetector::~DistillablePageDetector() { |
24 } | 40 } |
25 | 41 |
26 bool DistillablePageDetector::Classify( | 42 bool DistillablePageDetector::Classify( |
27 const std::vector<double>& features) const { | 43 const std::vector<double>& features) const { |
28 return Score(features) > threshold_; | 44 return Score(features) > threshold_; |
29 } | 45 } |
30 | 46 |
31 double DistillablePageDetector::Score( | 47 double DistillablePageDetector::Score( |
32 const std::vector<double>& features) const { | 48 const std::vector<double>& features) const { |
33 CHECK(features.size() == size_t(proto_->num_features())); | 49 if (features.size() != size_t(proto_->num_features())) { |
| 50 return 0.0; |
| 51 } |
34 double score = 0.0; | 52 double score = 0.0; |
35 for (int i = 0; i < proto_->num_stumps(); ++i) { | 53 for (int i = 0; i < proto_->num_stumps(); ++i) { |
36 const StumpProto& stump = proto_->stump(i); | 54 const StumpProto& stump = proto_->stump(i); |
37 if (features[stump.feature_number()] > stump.split()) { | 55 if (features[stump.feature_number()] > stump.split()) { |
38 score += stump.weight(); | 56 score += stump.weight(); |
39 } | 57 } |
40 } | 58 } |
41 return score; | 59 return score; |
42 } | 60 } |
43 | 61 |
44 double DistillablePageDetector::GetThreshold() const { | 62 double DistillablePageDetector::GetThreshold() const { |
45 return threshold_; | 63 return threshold_; |
46 } | 64 } |
47 | 65 |
48 } // namespace dom_distiller | 66 } // namespace dom_distiller |
OLD | NEW |