Index: components/translate/core/browser/translate_ranker_impl.cc |
diff --git a/components/translate/core/browser/translate_ranker_impl.cc b/components/translate/core/browser/translate_ranker_impl.cc |
index 6f34fa21f8ecd6065808caca0fd8c9eefbd9b5f9..d31d4718adabaa73fe5df4a2129f83f5a374f213 100644 |
--- a/components/translate/core/browser/translate_ranker_impl.cc |
+++ b/components/translate/core/browser/translate_ranker_impl.cc |
@@ -15,7 +15,6 @@ |
#include "base/metrics/histogram_macros.h" |
#include "base/metrics/metrics_hashes.h" |
#include "base/profiler/scoped_tracker.h" |
-#include "base/strings/string_number_conversions.h" |
#include "base/strings/string_util.h" |
#include "base/task_runner.h" |
#include "base/threading/thread_task_runner_handle.h" |
@@ -23,9 +22,6 @@ |
#include "components/translate/core/browser/proto/ranker_model.pb.h" |
#include "components/translate/core/browser/proto/translate_ranker_model.pb.h" |
#include "components/translate/core/browser/ranker_model.h" |
-#include "components/translate/core/browser/translate_download_manager.h" |
-#include "components/translate/core/browser/translate_prefs.h" |
-#include "components/translate/core/browser/translate_url_fetcher.h" |
#include "components/translate/core/common/translate_switches.h" |
#include "components/ukm/public/ukm_entry_builder.h" |
#include "components/ukm/public/ukm_recorder.h" |
@@ -106,17 +102,16 @@ TranslateRankerFeatures::TranslateRankerFeatures(int accepted, |
denied_ratio(SafeRatio(denied_count, total_count)), |
ignored_ratio(SafeRatio(ignored_count, total_count)) {} |
-TranslateRankerFeatures::TranslateRankerFeatures(const TranslatePrefs& prefs, |
- const std::string& src, |
- const std::string& dst, |
- const std::string& locale) |
- : TranslateRankerFeatures(prefs.GetTranslationAcceptedCount(src), |
- prefs.GetTranslationDeniedCount(src), |
- prefs.GetTranslationIgnoredCount(src), |
- src, |
- dst, |
- prefs.GetCountry(), |
- locale) {} |
+// TODO(hamelphi): Log locale in TranslateEventProtos. |
+TranslateRankerFeatures::TranslateRankerFeatures( |
+ const metrics::TranslateEventProto& translate_event) |
+ : TranslateRankerFeatures(translate_event.accept_count(), |
+ translate_event.decline_count(), |
+ translate_event.ignore_count(), |
+ translate_event.source_language(), |
+ translate_event.target_language(), |
+ translate_event.country(), |
+ "" /*locale*/) {} |
TranslateRankerFeatures::~TranslateRankerFeatures() {} |
@@ -195,9 +190,6 @@ uint32_t TranslateRankerImpl::GetModelVersion() const { |
} |
bool TranslateRankerImpl::ShouldOfferTranslation( |
- const TranslatePrefs& translate_prefs, |
- const std::string& src_lang, |
- const std::string& dst_lang, |
metrics::TranslateEventProto* translate_event) { |
DCHECK(sequence_checker_.CalledOnValidSequence()); |
// The ranker is a gate in the "show a translation prompt" flow. To retain |
@@ -234,11 +226,8 @@ bool TranslateRankerImpl::ShouldOfferTranslation( |
FROM_HERE_WITH_EXPLICIT_FUNCTION( |
"646711 translate::TranslateRankerImpl::ShouldOfferTranslation")); |
- TranslateRankerFeatures features( |
- translate_prefs, src_lang, dst_lang, |
- TranslateDownloadManager::GetInstance()->application_locale()); |
+ bool result = GetModelDecision(*translate_event); |
- bool result = GetModelDecision(features); |
UMA_HISTOGRAM_BOOLEAN("Translate.Ranker.QueryResult", result); |
translate_event->set_ranker_response( |
@@ -253,33 +242,38 @@ bool TranslateRankerImpl::ShouldOfferTranslation( |
} |
bool TranslateRankerImpl::GetModelDecision( |
- const TranslateRankerFeatures& features) { |
+ const metrics::TranslateEventProto& translate_event) { |
DCHECK(sequence_checker_.CalledOnValidSequence()); |
SCOPED_UMA_HISTOGRAM_TIMER("Translate.Ranker.Timer.CalculateScore"); |
DCHECK(model_ != nullptr); |
- const TranslateRankerModel::LogisticRegressionModel& logit = |
+ |
+ // TODO(hamelphi): consider deprecating TranslateRankerFeatures and move the |
+ // logic here. |
+ const TranslateRankerFeatures features(translate_event); |
+ |
+ const TranslateRankerModel::LogisticRegressionModel& lr_model = |
model_->proto().translate().logistic_regression_model(); |
double dot_product = |
- (features.accepted_count * logit.accept_count_weight()) + |
- (features.denied_count * logit.decline_count_weight()) + |
- (features.ignored_count * logit.ignore_count_weight()) + |
- (features.accepted_ratio * logit.accept_ratio_weight()) + |
- (features.denied_ratio * logit.decline_ratio_weight()) + |
- (features.ignored_ratio * logit.ignore_ratio_weight()) + |
- ScoreComponent(logit.source_language_weight(), features.src_lang) + |
- ScoreComponent(logit.target_language_weight(), features.dst_lang) + |
- ScoreComponent(logit.country_weight(), features.country) + |
- ScoreComponent(logit.locale_weight(), features.app_locale); |
- |
- double score = Sigmoid(dot_product + logit.bias()); |
+ (features.accepted_count * lr_model.accept_count_weight()) + |
+ (features.denied_count * lr_model.decline_count_weight()) + |
+ (features.ignored_count * lr_model.ignore_count_weight()) + |
+ (features.accepted_ratio * lr_model.accept_ratio_weight()) + |
+ (features.denied_ratio * lr_model.decline_ratio_weight()) + |
+ (features.ignored_ratio * lr_model.ignore_ratio_weight()) + |
+ ScoreComponent(lr_model.source_language_weight(), features.src_lang) + |
+ ScoreComponent(lr_model.target_language_weight(), features.dst_lang) + |
+ ScoreComponent(lr_model.country_weight(), features.country) + |
+ ScoreComponent(lr_model.locale_weight(), features.app_locale); |
+ |
+ double score = Sigmoid(dot_product + lr_model.bias()); |
DVLOG(2) << "TranslateRankerImpl::GetModelDecision: " |
<< "Score = " << score << ", Features=[" << features << "]"; |
float decision_threshold = kTranslationOfferDefaultThreshold; |
- if (logit.threshold() > 0) { |
- decision_threshold = logit.threshold(); |
+ if (lr_model.threshold() > 0) { |
+ decision_threshold = lr_model.threshold(); |
} |
return score >= decision_threshold; |