| OLD | NEW |
| 1 // Copyright 2015 The Chromium Authors. All rights reserved. | 1 // Copyright 2015 The Chromium Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
| 3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
| 4 | 4 |
| 5 #include "components/dom_distiller/core/distillable_page_detector.h" | 5 #include "components/dom_distiller/core/distillable_page_detector.h" |
| 6 | 6 |
| 7 #include "base/logging.h" | 7 #include "base/logging.h" |
| 8 #include "grit/components_resources.h" |
| 9 #include "ui/base/resource/resource_bundle.h" |
| 8 | 10 |
| 9 namespace dom_distiller { | 11 namespace dom_distiller { |
| 10 | 12 |
| 13 const DistillablePageDetector* DistillablePageDetector::GetDefault() { |
| 14 static DistillablePageDetector* detector = nullptr; |
| 15 if (!detector) { |
| 16 std::string serialized_proto = |
| 17 ResourceBundle::GetSharedInstance() |
| 18 .GetRawDataResource(IDR_DISTILLABLE_PAGE_SERIALIZED_MODEL) |
| 19 .as_string(); |
| 20 scoped_ptr<AdaBoostProto> proto(new AdaBoostProto); |
| 21 CHECK(proto->ParseFromString(serialized_proto)); |
| 22 detector = new DistillablePageDetector(proto.Pass()); |
| 23 } |
| 24 return detector; |
| 25 } |
| 26 |
| 11 DistillablePageDetector::DistillablePageDetector( | 27 DistillablePageDetector::DistillablePageDetector( |
| 12 scoped_ptr<AdaBoostProto> proto) | 28 scoped_ptr<AdaBoostProto> proto) |
| 13 : proto_(proto.Pass()), threshold_(0.0) { | 29 : proto_(proto.Pass()), threshold_(0.0) { |
| 14 CHECK(proto_->num_stumps() == proto_->stump_size()); | 30 CHECK(proto_->num_stumps() == proto_->stump_size()); |
| 15 for (int i = 0; i < proto_->num_stumps(); ++i) { | 31 for (int i = 0; i < proto_->num_stumps(); ++i) { |
| 16 const StumpProto& stump = proto_->stump(i); | 32 const StumpProto& stump = proto_->stump(i); |
| 17 CHECK(stump.feature_number() >= 0); | 33 CHECK(stump.feature_number() >= 0); |
| 18 CHECK(stump.feature_number() < proto_->num_features()); | 34 CHECK(stump.feature_number() < proto_->num_features()); |
| 19 threshold_ += stump.weight() / 2.0; | 35 threshold_ += stump.weight() / 2.0; |
| 20 } | 36 } |
| 21 } | 37 } |
| 22 | 38 |
| 23 DistillablePageDetector::~DistillablePageDetector() { | 39 DistillablePageDetector::~DistillablePageDetector() { |
| 24 } | 40 } |
| 25 | 41 |
| 26 bool DistillablePageDetector::Classify( | 42 bool DistillablePageDetector::Classify( |
| 27 const std::vector<double>& features) const { | 43 const std::vector<double>& features) const { |
| 28 return Score(features) > threshold_; | 44 return Score(features) > threshold_; |
| 29 } | 45 } |
| 30 | 46 |
| 31 double DistillablePageDetector::Score( | 47 double DistillablePageDetector::Score( |
| 32 const std::vector<double>& features) const { | 48 const std::vector<double>& features) const { |
| 33 CHECK(features.size() == size_t(proto_->num_features())); | 49 if (features.size() != size_t(proto_->num_features())) { |
| 50 return 0.0; |
| 51 } |
| 34 double score = 0.0; | 52 double score = 0.0; |
| 35 for (int i = 0; i < proto_->num_stumps(); ++i) { | 53 for (int i = 0; i < proto_->num_stumps(); ++i) { |
| 36 const StumpProto& stump = proto_->stump(i); | 54 const StumpProto& stump = proto_->stump(i); |
| 37 if (features[stump.feature_number()] > stump.split()) { | 55 if (features[stump.feature_number()] > stump.split()) { |
| 38 score += stump.weight(); | 56 score += stump.weight(); |
| 39 } | 57 } |
| 40 } | 58 } |
| 41 return score; | 59 return score; |
| 42 } | 60 } |
| 43 | 61 |
| 44 double DistillablePageDetector::GetThreshold() const { | 62 double DistillablePageDetector::GetThreshold() const { |
| 45 return threshold_; | 63 return threshold_; |
| 46 } | 64 } |
| 47 | 65 |
| 48 } // namespace dom_distiller | 66 } // namespace dom_distiller |
| OLD | NEW |