Index: chrome/browser/safe_browsing/client_side_detection_host.cc |
diff --git a/chrome/browser/safe_browsing/client_side_detection_host.cc b/chrome/browser/safe_browsing/client_side_detection_host.cc |
index 6eeb6f7d60d378e3f65a3ec079a2dce0472dd0cd..aeff3d6415d22ad40b4f37d082c6fc1db6ecbf29 100644 |
--- a/chrome/browser/safe_browsing/client_side_detection_host.cc |
+++ b/chrome/browser/safe_browsing/client_side_detection_host.cc |
@@ -12,6 +12,7 @@ |
#include "base/metrics/histogram.h" |
#include "base/prefs/pref_service.h" |
#include "base/sequenced_task_runner_helpers.h" |
+#include "base/strings/utf_string_conversions.h" |
#include "chrome/browser/browser_process.h" |
#include "chrome/browser/profiles/profile.h" |
#include "chrome/browser/safe_browsing/browser_feature_extractor.h" |
@@ -47,6 +48,8 @@ namespace safe_browsing { |
const int ClientSideDetectionHost::kMaxUrlsPerIP = 20; |
const int ClientSideDetectionHost::kMaxIPsPerBrowse = 200; |
+const char kSafeBrowsingMatchKey[] = "safe_browsing_match"; |
+ |
// This class is instantiated each time a new toplevel URL loads, and |
// asynchronously checks whether the phishing classifier should run for this |
// URL. If so, it notifies the renderer with a StartPhishingDetection IPC. |
@@ -248,8 +251,7 @@ ClientSideDetectionHost::ClientSideDetectionHost(WebContents* tab) |
weak_factory_(this), |
unsafe_unique_page_id_(-1), |
malware_killswitch_on_(false), |
- malware_report_enabled_(false), |
- malware_or_phishing_match_(false) { |
+ malware_report_enabled_(false) { |
DCHECK(tab); |
// Note: csd_service_ and sb_service will be NULL here in testing. |
csd_service_ = g_browser_process->safe_browsing_detection_service(); |
@@ -291,8 +293,6 @@ bool ClientSideDetectionHost::OnMessageReceived(const IPC::Message& message) { |
void ClientSideDetectionHost::DidNavigateMainFrame( |
const content::LoadCommittedDetails& details, |
const content::FrameNavigateParams& params) { |
- malware_or_phishing_match_ = false; |
- |
// TODO(noelutz): move this DCHECK to WebContents and fix all the unit tests |
// that don't call this method on the UI thread. |
// DCHECK(BrowserThread::CurrentlyOn(BrowserThread::UI)); |
@@ -355,6 +355,7 @@ void ClientSideDetectionHost::OnSafeBrowsingHit( |
// Store the unique page ID for later. |
unsafe_unique_page_id_ = |
web_contents()->GetController().GetActiveEntry()->GetUniqueID(); |
+ |
// We also keep the resource around in order to be able to send the |
// malicious URL to the server. |
unsafe_resource_.reset(new SafeBrowsingUIManager::UnsafeResource(resource)); |
@@ -363,7 +364,18 @@ void ClientSideDetectionHost::OnSafeBrowsingHit( |
void ClientSideDetectionHost::OnSafeBrowsingMatch( |
const SafeBrowsingUIManager::UnsafeResource& resource) { |
- malware_or_phishing_match_ = true; |
+ if (!web_contents() || !web_contents()->GetController().GetActiveEntry()) |
+ return; |
+ |
+ // Check that this notification is really for us. |
+ content::RenderViewHost* hit_rvh = content::RenderViewHost::FromID( |
+ resource.render_process_host_id, resource.render_view_id); |
+ if (!hit_rvh || |
+ web_contents() != content::WebContents::FromRenderViewHost(hit_rvh)) |
+ return; |
+ |
+ web_contents()->GetController().GetActiveEntry()->SetExtraData( |
+ kSafeBrowsingMatchKey, base::ASCIIToUTF16("1")); |
} |
scoped_refptr<SafeBrowsingDatabaseManager> |
@@ -372,7 +384,24 @@ ClientSideDetectionHost::database_manager() { |
} |
bool ClientSideDetectionHost::DidPageReceiveSafeBrowsingMatch() const { |
- return malware_or_phishing_match_ || DidShowSBInterstitial(); |
+ if (!web_contents() || !web_contents()->GetController().GetVisibleEntry()) |
+ return false; |
+ |
+ // If an interstitial page is showing, GetVisibleEntry will return the |
+ // transient NavigationEntry for the interstitial. The transient entry |
+ // will not have the flag set, so use the pending entry instead if there |
+ // is one. |
+ NavigationEntry* entry = web_contents()->GetController().GetPendingEntry(); |
+ if (!entry) { |
+ entry = web_contents()->GetController().GetVisibleEntry(); |
+ if (entry->GetPageType() == content::PAGE_TYPE_INTERSTITIAL) |
+ entry = web_contents()->GetController().GetLastCommittedEntry(); |
+ if (!entry) |
+ return false; |
+ } |
+ |
+ base::string16 value; |
+ return entry->GetExtraData(kSafeBrowsingMatchKey, &value); |
} |
void ClientSideDetectionHost::WebContentsDestroyed(WebContents* tab) { |