Chromium Code Reviews| Index: components/translate/core/browser/translate_ranker_impl.cc |
| diff --git a/components/translate/core/browser/translate_ranker_impl.cc b/components/translate/core/browser/translate_ranker_impl.cc |
| new file mode 100644 |
| index 0000000000000000000000000000000000000000..f7c3a15ab213a4a6be419b937a9ac57c9d0d22a5 |
| --- /dev/null |
| +++ b/components/translate/core/browser/translate_ranker_impl.cc |
| @@ -0,0 +1,281 @@ |
| +// Copyright 2016 The Chromium Authors. All rights reserved. |
|
groby-ooo-7-16
2017/02/23 00:01:35
Can you convince codereview this is mostly a copy
Roger McFarlane (Chromium)
2017/02/23 21:17:56
sent file specific diff out of band.
|
| +// Use of this source code is governed by a BSD-style license that can be |
| +// found in the LICENSE file. |
| + |
| +#include "components/translate/core/browser/translate_ranker_impl.h" |
| + |
| +#include <cmath> |
| + |
| +#include "base/bind.h" |
| +#include "base/bind_helpers.h" |
| +#include "base/command_line.h" |
| +#include "base/files/file_path.h" |
| +#include "base/files/file_util.h" |
| +#include "base/memory/ptr_util.h" |
| +#include "base/metrics/histogram_macros.h" |
| +#include "base/profiler/scoped_tracker.h" |
| +#include "base/strings/string_number_conversions.h" |
| +#include "base/strings/string_util.h" |
| +#include "base/task_runner.h" |
| +#include "base/threading/thread_task_runner_handle.h" |
| +#include "components/metrics/proto/translate_event.pb.h" |
| +#include "components/translate/core/browser/proto/ranker_model.pb.h" |
| +#include "components/translate/core/browser/proto/translate_ranker_model.pb.h" |
| +#include "components/translate/core/browser/ranker_model.h" |
| +#include "components/translate/core/browser/translate_download_manager.h" |
| +#include "components/translate/core/browser/translate_prefs.h" |
| +#include "components/translate/core/browser/translate_url_fetcher.h" |
| +#include "components/translate/core/common/translate_switches.h" |
| +#include "components/variations/variations_associated_data.h" |
| +#include "url/gurl.h" |
| + |
| +namespace translate { |
| + |
| +namespace { |
| + |
| +using chrome_intelligence::RankerModel; |
| +using chrome_intelligence::RankerModelProto; |
| +using chrome_intelligence::TranslateRankerModel; |
| + |
| +const double kTranslationOfferThreshold = 0.5; |
| + |
| +const char kTranslateRankerModelFileName[] = "Translate Ranker Model"; |
| +const char kUmaPrefix[] = "Translate.Ranker"; |
| +const char kUnknown[] = "UNKNOWN"; |
| + |
| +double Sigmoid(double x) { |
| + return 1.0 / (1.0 + exp(-x)); |
| +} |
| + |
| +double SafeRatio(int numerator, int denominator) { |
| + return denominator ? (numerator / static_cast<double>(denominator)) : 0.0; |
| +} |
| + |
| +double ScoreComponent(const google::protobuf::Map<std::string, float>& weights, |
| + const std::string& key) { |
| + auto i = weights.find(base::ToLowerASCII(key)); |
| + if (i == weights.end()) |
| + i = weights.find(kUnknown); |
| + return i == weights.end() ? 0.0 : i->second; |
| +} |
| + |
| +RankerModelStatus ValidateModel(const RankerModel& model) { |
| + if (model.proto().model_case() != RankerModelProto::kTranslate) |
| + return RankerModelStatus::VALIDATION_FAILED; |
| + |
| + if (model.proto().translate().model_revision_case() != |
| + TranslateRankerModel::kLogisticRegressionModel) { |
| + return RankerModelStatus::INCOMPATIBLE; |
| + } |
| + |
| + return RankerModelStatus::OK; |
| +} |
| + |
| +} // namespace |
| + |
| +const base::Feature kTranslateRankerQuery{"TranslateRankerQuery", |
| + base::FEATURE_DISABLED_BY_DEFAULT}; |
| + |
| +const base::Feature kTranslateRankerEnforcement{ |
| + "TranslateRankerEnforcement", base::FEATURE_DISABLED_BY_DEFAULT}; |
| + |
| +const base::Feature kTranslateRankerLogging{"TranslateRankerLogging", |
| + base::FEATURE_DISABLED_BY_DEFAULT}; |
| + |
| +TranslateRankerFeatures::TranslateRankerFeatures() {} |
| + |
| +TranslateRankerFeatures::TranslateRankerFeatures(int accepted, |
| + int denied, |
| + int ignored, |
| + const std::string& src, |
| + const std::string& dst, |
| + const std::string& cntry, |
| + const std::string& locale) |
| + : accepted_count(accepted), |
| + denied_count(denied), |
| + ignored_count(ignored), |
| + total_count(accepted_count + denied_count + ignored_count), |
| + src_lang(src), |
| + dst_lang(dst), |
| + country(cntry), |
| + app_locale(locale), |
| + accepted_ratio(SafeRatio(accepted_count, total_count)), |
| + denied_ratio(SafeRatio(denied_count, total_count)), |
| + ignored_ratio(SafeRatio(ignored_count, total_count)) {} |
| + |
| +TranslateRankerFeatures::TranslateRankerFeatures(const TranslatePrefs& prefs, |
| + const std::string& src, |
| + const std::string& dst, |
| + const std::string& locale) |
| + : TranslateRankerFeatures(prefs.GetTranslationAcceptedCount(src), |
| + prefs.GetTranslationDeniedCount(src), |
| + prefs.GetTranslationIgnoredCount(src), |
| + src, |
| + dst, |
| + prefs.GetCountry(), |
| + locale) {} |
| + |
| +TranslateRankerFeatures::~TranslateRankerFeatures() {} |
| + |
| +void TranslateRankerFeatures::WriteTo(std::ostream& stream) const { |
| + stream << "src_lang='" << src_lang << "', " |
| + << "dst_lang='" << dst_lang << "', " |
| + << "country='" << country << "', " |
| + << "app_locale='" << app_locale << "', " |
| + << "accept_count=" << accepted_count << ", " |
| + << "denied_count=" << denied_count << ", " |
| + << "ignored_count=" << ignored_count << ", " |
| + << "total_count=" << total_count << ", " |
| + << "accept_ratio=" << accepted_ratio << ", " |
| + << "decline_ratio=" << denied_ratio << ", " |
| + << "ignore_ratio=" << ignored_ratio; |
| +} |
| + |
| +TranslateRankerImpl::TranslateRankerImpl(const base::FilePath& model_path, |
| + const GURL& model_url) |
| + : weak_ptr_factory_(this) { |
| + model_loader_ = base::MakeUnique<RankerModelLoader>( |
| + base::Bind(&ValidateModel), |
| + base::Bind(&TranslateRankerImpl::OnModelAvailable, |
| + weak_ptr_factory_.GetWeakPtr()), |
| + model_path, model_url, kUmaPrefix); |
| + model_loader_->Start(); |
| +} |
| + |
| +TranslateRankerImpl::~TranslateRankerImpl() {} |
| + |
| +// static |
| +base::FilePath TranslateRankerImpl::GetModelPath( |
| + const base::FilePath& data_dir) { |
| + if (data_dir.empty()) |
| + return base::FilePath(); |
| + |
| + // Otherwise, look for the file in data dir. |
| + return data_dir.AppendASCII(kTranslateRankerModelFileName); |
| +} |
| + |
| +// static |
| +GURL TranslateRankerImpl::GetModelURL() { |
| + // Allow override of the ranker model URL from the command line. |
| + base::CommandLine* command_line = base::CommandLine::ForCurrentProcess(); |
| + if (command_line->HasSwitch(switches::kTranslateRankerModelURL)) { |
| + return GURL( |
| + command_line->GetSwitchValueASCII(switches::kTranslateRankerModelURL)); |
| + } |
| + |
| + // Otherwise take the ranker model URL from the ranker query variation. |
| + const std::string raw_url = variations::GetVariationParamValueByFeature( |
| + kTranslateRankerQuery, switches::kTranslateRankerModelURL); |
| + |
| + DVLOG(3) << switches::kTranslateRankerModelURL << " = " << raw_url; |
| + |
| + return GURL(raw_url); |
| +} |
| + |
| +bool TranslateRankerImpl::IsLoggingEnabled() { |
| + return base::FeatureList::IsEnabled(kTranslateRankerLogging); |
| +} |
| + |
| +bool TranslateRankerImpl::IsQueryEnabled() { |
| + return base::FeatureList::IsEnabled(kTranslateRankerQuery); |
| +} |
| + |
| +bool TranslateRankerImpl::IsEnforcementEnabled() { |
| + return base::FeatureList::IsEnabled(kTranslateRankerEnforcement); |
| +} |
| + |
| +int TranslateRankerImpl::GetModelVersion() const { |
| + return model_ ? model_->proto().translate().version() : 0; |
| +} |
| + |
| +bool TranslateRankerImpl::ShouldOfferTranslation( |
| + const TranslatePrefs& translate_prefs, |
| + const std::string& src_lang, |
| + const std::string& dst_lang) { |
| + DCHECK(sequence_checker_.CalledOnValidSequence()); |
| + // The ranker is a gate in the "show a translation prompt" flow. To retain |
| + // the pre-existing functionality, it defaults to returning true in the |
| + // absence of a model or if enforcement is disabled. As this is ranker is |
| + // subsumed into a more general assist ranker, this default will go away |
| + // (or become False). |
| + const bool kDefaultResponse = true; |
| + |
| + // If we don't have a model, request one and return the default. |
| + if (model_ == nullptr) { |
| + return kDefaultResponse; |
| + } |
| + |
| + SCOPED_UMA_HISTOGRAM_TIMER("Translate.Ranker.Timer.ShouldOfferTranslation"); |
| + |
| + // TODO(rogerm): Remove ScopedTracker below once crbug.com/646711 is closed. |
| + tracked_objects::ScopedTracker tracking_profile( |
| + FROM_HERE_WITH_EXPLICIT_FUNCTION( |
| + "646711 translate::TranslateRankerImpl::ShouldOfferTranslation")); |
| + |
| + TranslateRankerFeatures features( |
| + translate_prefs, src_lang, dst_lang, |
| + TranslateDownloadManager::GetInstance()->application_locale()); |
| + |
| + double score = CalculateScore(features); |
| + |
| + DVLOG(2) << "TranslateRankerImpl::ShouldOfferTranslation: " |
| + << "Score = " << score << ", Features=[" << features << "]"; |
| + |
| + bool result = (score >= kTranslationOfferThreshold); |
| + |
| + UMA_HISTOGRAM_BOOLEAN("Translate.Ranker.QueryResult", result); |
| + |
| + return result; |
| +} |
| + |
| +double TranslateRankerImpl::CalculateScore( |
| + const TranslateRankerFeatures& features) { |
| + DCHECK(sequence_checker_.CalledOnValidSequence()); |
| + SCOPED_UMA_HISTOGRAM_TIMER("Translate.Ranker.Timer.CalculateScore"); |
| + DCHECK(model_ != nullptr); |
| + const TranslateRankerModel::LogisticRegressionModel& logit = |
| + model_->proto().translate().logistic_regression_model(); |
| + |
| + double dot_product = |
| + (features.accepted_count * logit.accept_count_weight()) + |
| + (features.denied_count * logit.decline_count_weight()) + |
| + (features.ignored_count * logit.ignore_count_weight()) + |
| + (features.accepted_ratio * logit.accept_ratio_weight()) + |
| + (features.denied_ratio * logit.decline_ratio_weight()) + |
| + (features.ignored_ratio * logit.ignore_ratio_weight()) + |
| + ScoreComponent(logit.source_language_weight(), features.src_lang) + |
| + ScoreComponent(logit.dest_language_weight(), features.dst_lang) + |
| + ScoreComponent(logit.country_weight(), features.country) + |
| + ScoreComponent(logit.locale_weight(), features.app_locale); |
| + |
| + return Sigmoid(dot_product + logit.bias()); |
| +} |
| + |
| +void TranslateRankerImpl::FlushTranslateEvents( |
| + std::vector<metrics::TranslateEventProto>* events) { |
| + DCHECK(sequence_checker_.CalledOnValidSequence()); |
| + DVLOG(3) << "Flushing translate ranker events."; |
| + events->swap(event_cache_); |
| + event_cache_.clear(); |
| +} |
| + |
| +void TranslateRankerImpl::AddTranslateEvent( |
| + const metrics::TranslateEventProto& event) { |
| + DCHECK(sequence_checker_.CalledOnValidSequence()); |
| + DVLOG(3) << "Adding translate ranker event."; |
| + if (IsLoggingEnabled()) |
| + event_cache_.push_back(event); |
| +} |
| + |
| +void TranslateRankerImpl::OnModelAvailable(std::unique_ptr<RankerModel> model) { |
| + DCHECK(sequence_checker_.CalledOnValidSequence()); |
| + model_ = std::move(model); |
| +} |
| + |
| +} // namespace translate |
| + |
| +std::ostream& operator<<(std::ostream& stream, |
| + const translate::TranslateRankerFeatures& features) { |
| + features.WriteTo(stream); |
| + return stream; |
| +} |