Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(493)

Unified Diff: components/translate/core/browser/translate_ranker_impl.cc

Issue 2565873002: [translate] Add translate ranker model loader. (Closed)
Patch Set: comments from sdefresne Created 3 years, 10 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View side-by-side diff with in-line comments
Download patch
Index: components/translate/core/browser/translate_ranker_impl.cc
diff --git a/components/translate/core/browser/translate_ranker_impl.cc b/components/translate/core/browser/translate_ranker_impl.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ed4a9b1f2df21957ae73d17330d31d48db803f2c
--- /dev/null
+++ b/components/translate/core/browser/translate_ranker_impl.cc
@@ -0,0 +1,282 @@
+// Copyright 2016 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "components/translate/core/browser/translate_ranker_impl.h"
+
+#include <cmath>
+
+#include "base/bind.h"
+#include "base/bind_helpers.h"
+#include "base/command_line.h"
+#include "base/files/file_path.h"
+#include "base/files/file_util.h"
+#include "base/memory/ptr_util.h"
+#include "base/metrics/histogram_macros.h"
+#include "base/profiler/scoped_tracker.h"
+#include "base/strings/string_number_conversions.h"
+#include "base/strings/string_util.h"
+#include "base/task_runner.h"
+#include "base/threading/thread_task_runner_handle.h"
+#include "components/metrics/proto/translate_event.pb.h"
+#include "components/translate/core/browser/proto/ranker_model.pb.h"
+#include "components/translate/core/browser/proto/translate_ranker_model.pb.h"
+#include "components/translate/core/browser/ranker_model.h"
+#include "components/translate/core/browser/translate_download_manager.h"
+#include "components/translate/core/browser/translate_prefs.h"
+#include "components/translate/core/browser/translate_url_fetcher.h"
+#include "components/translate/core/common/translate_switches.h"
+#include "components/variations/variations_associated_data.h"
+#include "url/gurl.h"
+
+namespace translate {
+
+namespace {
+
+using chrome_intelligence::RankerModel;
+using chrome_intelligence::RankerModelProto;
+using chrome_intelligence::TranslateRankerModel;
+
+const double kTranslationOfferThreshold = 0.5;
+
+const char kTranslateRankerModelFileName[] = "Translate Ranker Model";
+const char kUmaPrefix[] = "Translate.Ranker";
+const char kUnknown[] = "UNKNOWN";
+
+double Sigmoid(double x) {
+ return 1.0 / (1.0 + exp(-x));
+}
+
+double SafeRatio(int numerator, int denominator) {
+ return denominator ? (numerator / static_cast<double>(denominator)) : 0.0;
+}
+
+double ScoreComponent(const google::protobuf::Map<std::string, float>& weights,
+ const std::string& key) {
+ auto i = weights.find(base::ToLowerASCII(key));
+ if (i == weights.end())
+ i = weights.find(kUnknown);
+ return i == weights.end() ? 0.0 : i->second;
+}
+
+RankerModelStatus ValidateModel(const RankerModel& model) {
+ if (model.proto().model_case() != RankerModelProto::kTranslate)
+ return RankerModelStatus::VALIDATION_FAILED;
+
+ if (model.proto().translate().model_revision_case() !=
+ TranslateRankerModel::kLogisticRegressionModel) {
+ return RankerModelStatus::INCOMPATIBLE;
+ }
+
+ return RankerModelStatus::OK;
+}
+
+} // namespace
+
+const base::Feature kTranslateRankerQuery{"TranslateRankerQuery",
+ base::FEATURE_DISABLED_BY_DEFAULT};
+
+const base::Feature kTranslateRankerEnforcement{
+ "TranslateRankerEnforcement", base::FEATURE_DISABLED_BY_DEFAULT};
+
+const base::Feature kTranslateRankerLogging{"TranslateRankerLogging",
+ base::FEATURE_DISABLED_BY_DEFAULT};
+
+TranslateRankerFeatures::TranslateRankerFeatures() {}
+
+TranslateRankerFeatures::TranslateRankerFeatures(int accepted,
+ int denied,
+ int ignored,
+ const std::string& src,
+ const std::string& dst,
+ const std::string& cntry,
+ const std::string& locale)
+ : accepted_count(accepted),
+ denied_count(denied),
+ ignored_count(ignored),
+ total_count(accepted_count + denied_count + ignored_count),
+ src_lang(src),
+ dst_lang(dst),
+ country(cntry),
+ app_locale(locale),
+ accepted_ratio(SafeRatio(accepted_count, total_count)),
+ denied_ratio(SafeRatio(denied_count, total_count)),
+ ignored_ratio(SafeRatio(ignored_count, total_count)) {}
+
+TranslateRankerFeatures::TranslateRankerFeatures(const TranslatePrefs& prefs,
+ const std::string& src,
+ const std::string& dst,
+ const std::string& locale)
+ : TranslateRankerFeatures(prefs.GetTranslationAcceptedCount(src),
+ prefs.GetTranslationDeniedCount(src),
+ prefs.GetTranslationIgnoredCount(src),
+ src,
+ dst,
+ prefs.GetCountry(),
+ locale) {}
+
+TranslateRankerFeatures::~TranslateRankerFeatures() {}
+
+void TranslateRankerFeatures::WriteTo(std::ostream& stream) const {
+ stream << "src_lang='" << src_lang << "', "
+ << "dst_lang='" << dst_lang << "', "
+ << "country='" << country << "', "
+ << "app_locale='" << app_locale << "', "
+ << "accept_count=" << accepted_count << ", "
+ << "denied_count=" << denied_count << ", "
+ << "ignored_count=" << ignored_count << ", "
+ << "total_count=" << total_count << ", "
+ << "accept_ratio=" << accepted_ratio << ", "
+ << "decline_ratio=" << denied_ratio << ", "
+ << "ignore_ratio=" << ignored_ratio;
+}
+
+TranslateRankerImpl::TranslateRankerImpl(const base::FilePath& model_path,
+ const GURL& model_url)
+ : weak_ptr_factory_(this) {
+ model_loader_ = base::MakeUnique<RankerModelLoader>(
+ base::Bind(&ValidateModel),
+ base::Bind(&TranslateRankerImpl::OnModelAvailable,
+ weak_ptr_factory_.GetWeakPtr()),
+ model_path, model_url, kUmaPrefix);
+ model_loader_->Start();
+}
+
+TranslateRankerImpl::~TranslateRankerImpl() {}
+
+// static
+base::FilePath TranslateRankerImpl::GetModelPath(
+ const base::FilePath& data_dir) {
+ if (data_dir.empty())
+ return base::FilePath();
+
+ // Otherwise, look for the file in data dir.
+ return data_dir.AppendASCII(kTranslateRankerModelFileName);
+}
+
+// static
+GURL TranslateRankerImpl::GetModelURL() {
+ // Allow override of the ranker model URL from the command line.
+ std::string raw_url;
+ base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
+ if (command_line->HasSwitch(switches::kTranslateRankerModelURL)) {
+ raw_url =
+ command_line->GetSwitchValueASCII(switches::kTranslateRankerModelURL);
+ } else {
+ // Otherwise take the ranker model URL from the ranker query variation.
+ raw_url = variations::GetVariationParamValueByFeature(
+ kTranslateRankerQuery, switches::kTranslateRankerModelURL);
+ }
+
+ DVLOG(3) << switches::kTranslateRankerModelURL << " = " << raw_url;
+
+ return GURL(raw_url);
+}
+
+bool TranslateRankerImpl::IsLoggingEnabled() {
+ return base::FeatureList::IsEnabled(kTranslateRankerLogging);
+}
+
+bool TranslateRankerImpl::IsQueryEnabled() {
+ return base::FeatureList::IsEnabled(kTranslateRankerQuery);
+}
+
+bool TranslateRankerImpl::IsEnforcementEnabled() {
+ return base::FeatureList::IsEnabled(kTranslateRankerEnforcement);
+}
+
+int TranslateRankerImpl::GetModelVersion() const {
+ return model_ ? model_->proto().translate().version() : 0;
+}
+
+bool TranslateRankerImpl::ShouldOfferTranslation(
+ const TranslatePrefs& translate_prefs,
+ const std::string& src_lang,
+ const std::string& dst_lang) {
+ DCHECK(sequence_checker_.CalledOnValidSequence());
+ // The ranker is a gate in the "show a translation prompt" flow. To retain
+ // the pre-existing functionality, it defaults to returning true in the
+ // absence of a model or if enforcement is disabled. As this is ranker is
+ // subsumed into a more general assist ranker, this default will go away
+ // (or become False).
+ const bool kDefaultResponse = true;
+
+ // If we don't have a model, request one and return the default.
+ if (model_ == nullptr) {
+ return kDefaultResponse;
+ }
+
+ SCOPED_UMA_HISTOGRAM_TIMER("Translate.Ranker.Timer.ShouldOfferTranslation");
+
+ // TODO(rogerm): Remove ScopedTracker below once crbug.com/646711 is closed.
+ tracked_objects::ScopedTracker tracking_profile(
+ FROM_HERE_WITH_EXPLICIT_FUNCTION(
+ "646711 translate::TranslateRankerImpl::ShouldOfferTranslation"));
+
+ TranslateRankerFeatures features(
+ translate_prefs, src_lang, dst_lang,
+ TranslateDownloadManager::GetInstance()->application_locale());
+
+ double score = CalculateScore(features);
+
+ DVLOG(2) << "TranslateRankerImpl::ShouldOfferTranslation: "
+ << "Score = " << score << ", Features=[" << features << "]";
+
+ bool result = (score >= kTranslationOfferThreshold);
+
+ UMA_HISTOGRAM_BOOLEAN("Translate.Ranker.QueryResult", result);
+
+ return result;
+}
+
+double TranslateRankerImpl::CalculateScore(
+ const TranslateRankerFeatures& features) {
+ DCHECK(sequence_checker_.CalledOnValidSequence());
+ SCOPED_UMA_HISTOGRAM_TIMER("Translate.Ranker.Timer.CalculateScore");
+ DCHECK(model_ != nullptr);
+ const TranslateRankerModel::LogisticRegressionModel& logit =
+ model_->proto().translate().logistic_regression_model();
+
+ double dot_product =
+ (features.accepted_count * logit.accept_count_weight()) +
+ (features.denied_count * logit.decline_count_weight()) +
+ (features.ignored_count * logit.ignore_count_weight()) +
+ (features.accepted_ratio * logit.accept_ratio_weight()) +
+ (features.denied_ratio * logit.decline_ratio_weight()) +
+ (features.ignored_ratio * logit.ignore_ratio_weight()) +
+ ScoreComponent(logit.source_language_weight(), features.src_lang) +
+ ScoreComponent(logit.dest_language_weight(), features.dst_lang) +
+ ScoreComponent(logit.country_weight(), features.country) +
+ ScoreComponent(logit.locale_weight(), features.app_locale);
+
+ return Sigmoid(dot_product + logit.bias());
+}
+
+void TranslateRankerImpl::FlushTranslateEvents(
+ std::vector<metrics::TranslateEventProto>* events) {
+ DCHECK(sequence_checker_.CalledOnValidSequence());
+ DVLOG(3) << "Flushing translate ranker events.";
+ events->swap(event_cache_);
+ event_cache_.clear();
+}
+
+void TranslateRankerImpl::AddTranslateEvent(
+ const metrics::TranslateEventProto& event) {
+ DCHECK(sequence_checker_.CalledOnValidSequence());
+ DVLOG(3) << "Adding translate ranker event.";
+ if (IsLoggingEnabled())
+ event_cache_.push_back(event);
+}
+
+void TranslateRankerImpl::OnModelAvailable(std::unique_ptr<RankerModel> model) {
+ DCHECK(sequence_checker_.CalledOnValidSequence());
+ model_ = std::move(model);
+}
+
+} // namespace translate
+
+std::ostream& operator<<(std::ostream& stream,
+ const translate::TranslateRankerFeatures& features) {
+ features.WriteTo(stream);
+ return stream;
+}

Powered by Google App Engine
This is Rietveld 408576698