Index: chrome/browser/safe_browsing/browser_feature_extractor_unittest.cc |
diff --git a/chrome/browser/safe_browsing/browser_feature_extractor_unittest.cc b/chrome/browser/safe_browsing/browser_feature_extractor_unittest.cc |
index 163268dffc1a259e5ec777a58c17a16fb49db463..f808aae97f1900016f47d33ce952a6e7c16303bd 100644 |
--- a/chrome/browser/safe_browsing/browser_feature_extractor_unittest.cc |
+++ b/chrome/browser/safe_browsing/browser_feature_extractor_unittest.cc |
@@ -17,7 +17,9 @@ |
#include "chrome/browser/history/history_service_factory.h" |
#include "chrome/browser/profiles/profile.h" |
#include "chrome/browser/safe_browsing/browser_features.h" |
-#include "chrome/browser/safe_browsing/client_side_detection_service.h" |
+#include "chrome/browser/safe_browsing/client_side_detection_host.h" |
+#include "chrome/browser/safe_browsing/database_manager.h" |
+#include "chrome/browser/safe_browsing/safe_browsing_service.h" |
#include "chrome/browser/safe_browsing/ui_manager.h" |
#include "chrome/common/safe_browsing/csd.pb.h" |
#include "chrome/test/base/chrome_render_view_host_test_harness.h" |
@@ -32,18 +34,44 @@ |
#include "testing/gtest/include/gtest/gtest.h" |
#include "url/gurl.h" |
+using content::BrowserThread; |
using content::WebContentsTester; |
+ |
+using testing::DoAll; |
using testing::Return; |
using testing::StrictMock; |
namespace safe_browsing { |
+ |
namespace { |
-class MockClientSideDetectionService : public ClientSideDetectionService { |
+ |
+class MockSafeBrowsingDatabaseManager : public SafeBrowsingDatabaseManager { |
+ public: |
+ explicit MockSafeBrowsingDatabaseManager( |
+ const scoped_refptr<SafeBrowsingService>& service) |
+ : SafeBrowsingDatabaseManager(service) { } |
+ |
+ MOCK_METHOD1(MatchMalwareIP, bool(const std::string& ip_address)); |
+ |
+ protected: |
+ virtual ~MockSafeBrowsingDatabaseManager() {} |
+ |
+ private: |
+ DISALLOW_COPY_AND_ASSIGN(MockSafeBrowsingDatabaseManager); |
+}; |
+ |
+class MockClientSideDetectionHost : public ClientSideDetectionHost { |
public: |
- MockClientSideDetectionService() : ClientSideDetectionService(NULL) {} |
- virtual ~MockClientSideDetectionService() {}; |
+ MockClientSideDetectionHost( |
+ content::WebContents* tab, |
+ SafeBrowsingDatabaseManager* database_manager) |
+ : ClientSideDetectionHost(tab) { |
+ set_safe_browsing_managers(NULL, database_manager); |
+ } |
+ |
+ virtual ~MockClientSideDetectionHost() {}; |
- MOCK_CONST_METHOD1(IsBadIpAddress, bool(const std::string&)); |
+ MOCK_METHOD1(IsBadIpAddress, bool(const std::string&)); |
}; |
} // namespace |
@@ -53,15 +81,21 @@ class BrowserFeatureExtractorTest : public ChromeRenderViewHostTestHarness { |
ChromeRenderViewHostTestHarness::SetUp(); |
ASSERT_TRUE(profile()->CreateHistoryService( |
true /* delete_file */, false /* no_db */)); |
- service_.reset(new StrictMock<MockClientSideDetectionService>()); |
+ |
+ db_manager_ = new StrictMock<MockSafeBrowsingDatabaseManager>( |
+ SafeBrowsingService::CreateSafeBrowsingService()); |
+ host_.reset(new StrictMock<MockClientSideDetectionHost>( |
+ web_contents(), db_manager_.get())); |
extractor_.reset( |
- new BrowserFeatureExtractor(web_contents(), service_.get())); |
+ new BrowserFeatureExtractor(web_contents(), host_.get())); |
num_pending_ = 0; |
browse_info_.reset(new BrowseInfo); |
} |
virtual void TearDown() { |
extractor_.reset(); |
+ host_.reset(); |
+ db_manager_ = NULL; |
profile()->DestroyHistoryService(); |
ChromeRenderViewHostTestHarness::TearDown(); |
ASSERT_EQ(0, num_pending_); |
@@ -138,8 +172,24 @@ class BrowserFeatureExtractorTest : public ChromeRenderViewHostTestHarness { |
} |
void ExtractMalwareFeatures(ClientMalwareRequest* request) { |
+ // Feature extraction takes ownership of the request object |
+ // and passes it along to the done callback in the end. |
+ StartExtractMalwareFeatures(request); |
+ base::MessageLoopForUI::current()->Run(); |
+ EXPECT_EQ(1U, success_.count(request)); |
+ EXPECT_TRUE(success_[request]); |
+ } |
+ |
+ void StartExtractMalwareFeatures(ClientMalwareRequest* request) { |
+ success_.erase(request); |
+ ++num_pending_; |
+ // We temporarily give up ownership of request to ExtractMalwareFeatures |
+ // but we'll regain ownership of it in ExtractMalwareFeaturesDone. |
extractor_->ExtractMalwareFeatures( |
- browse_info_.get(), request); |
+ browse_info_.get(), |
+ request, |
+ base::Bind(&BrowserFeatureExtractorTest::ExtractMalwareFeaturesDone, |
+ base::Unretained(this))); |
} |
void GetMalwareFeatureMap( |
@@ -157,11 +207,12 @@ class BrowserFeatureExtractorTest : public ChromeRenderViewHostTestHarness { |
} |
} |
- int num_pending_; |
+ int num_pending_; // Number of pending feature extractions. |
scoped_ptr<BrowserFeatureExtractor> extractor_; |
- std::map<ClientPhishingRequest*, bool> success_; |
+ std::map<void*, bool> success_; |
scoped_ptr<BrowseInfo> browse_info_; |
- scoped_ptr<MockClientSideDetectionService> service_; |
+ scoped_ptr<StrictMock<MockClientSideDetectionHost> > host_; |
+ scoped_refptr<StrictMock<MockSafeBrowsingDatabaseManager> > db_manager_; |
private: |
void ExtractFeaturesDone(bool success, ClientPhishingRequest* request) { |
@@ -171,6 +222,19 @@ class BrowserFeatureExtractorTest : public ChromeRenderViewHostTestHarness { |
base::MessageLoop::current()->Quit(); |
} |
} |
+ |
+ void ExtractMalwareFeaturesDone( |
+ bool success, |
+ scoped_ptr<ClientMalwareRequest> request) { |
+ EXPECT_TRUE(BrowserThread::CurrentlyOn(BrowserThread::UI)); |
+ ASSERT_EQ(0U, success_.count(request.get())); |
+ // The pointer doesn't really belong to us. It belongs to |
+ // the test case which passed it to ExtractMalwareFeatures above. |
+ success_[request.release()] = success; |
+ if (--num_pending_ == 0) { |
+ base::MessageLoopForUI::current()->Quit(); |
+ } |
+ } |
}; |
TEST_F(BrowserFeatureExtractorTest, UrlNotInHistory) { |
@@ -480,13 +544,6 @@ TEST_F(BrowserFeatureExtractorTest, BrowseFeatures) { |
GURL("https://bankofamerica.com"), |
content::PAGE_TRANSITION_GENERATED); |
- std::set<std::string> urls; |
- urls.insert("http://test.com"); |
- browse_info_->ips.insert(std::make_pair("193.5.163.8", urls)); |
- browse_info_->ips.insert(std::make_pair("23.94.78.1", urls)); |
- EXPECT_CALL(*service_, IsBadIpAddress("193.5.163.8")).WillOnce(Return(true)); |
- EXPECT_CALL(*service_, IsBadIpAddress("23.94.78.1")).WillOnce(Return(false)); |
- |
EXPECT_TRUE(ExtractFeatures(&request)); |
features.clear(); |
GetFeatureMap(request, &features); |
@@ -507,9 +564,6 @@ TEST_F(BrowserFeatureExtractorTest, BrowseFeatures) { |
features::kHostPrefix, |
features::kIsFirstNavigation))); |
EXPECT_EQ(5.0, features[features::kPageTransitionType]); |
- EXPECT_EQ(1.0, features[std::string(features::kBadIpFetch) + "193.5.163.8"]); |
- EXPECT_FALSE(features.count(std::string(features::kBadIpFetch) + |
- "23.94.78.1")); |
} |
TEST_F(BrowserFeatureExtractorTest, SafeBrowsingFeatures) { |
@@ -552,9 +606,12 @@ TEST_F(BrowserFeatureExtractorTest, MalwareFeatures) { |
std::set<std::string> good_urls; |
good_urls.insert("http://ok.com"); |
browse_info_->ips.insert(std::make_pair("23.94.78.1", good_urls)); |
- EXPECT_CALL(*service_, IsBadIpAddress("193.5.163.8")).WillOnce(Return(true)); |
- EXPECT_CALL(*service_, IsBadIpAddress("92.92.92.92")).WillOnce(Return(true)); |
- EXPECT_CALL(*service_, IsBadIpAddress("23.94.78.1")).WillOnce(Return(false)); |
+ EXPECT_CALL(*db_manager_, MatchMalwareIP("193.5.163.8")) |
+ .WillOnce(Return(true)); |
+ EXPECT_CALL(*db_manager_, MatchMalwareIP("92.92.92.92")) |
+ .WillOnce(Return(true)); |
+ EXPECT_CALL(*db_manager_, MatchMalwareIP("23.94.78.1")) |
+ .WillOnce(Return(false)); |
ExtractMalwareFeatures(&request); |
std::map<std::string, std::set<std::string> > features; |
@@ -588,20 +645,16 @@ TEST_F(BrowserFeatureExtractorTest, MalwareFeatures_ExceedLimit) { |
std::string ip = base::StringPrintf("%d.%d.%d.%d", i, i, i, i); |
ips.push_back(ip); |
browse_info_->ips.insert(std::make_pair(ip, bad_urls)); |
- } |
- // First ip is good, then check the next 5 bad ips. |
- // Not check the 7th as reached limit. |
- EXPECT_CALL(*service_, IsBadIpAddress(ips[0])).WillOnce(Return(false)); |
- for (int i = 1; i < 6; ++i) { |
- EXPECT_CALL(*service_, IsBadIpAddress(ips[i])).WillOnce(Return(true)); |
+ // First ip is good but all the others are bad. |
+ EXPECT_CALL(*db_manager_, MatchMalwareIP(ip)).WillOnce(Return(i > 0)); |
} |
ExtractMalwareFeatures(&request); |
std::map<std::string, std::set<std::string> > features; |
GetMalwareFeatureMap(request, &features); |
- // Only keep 5 ips. |
+ // The number of IP match features we store is capped at 5 IPs per request. |
EXPECT_EQ(5U, features.size()); |
} |