Unified Diff: appengine/monorail/tools/spam/spam.py

Issue 1868553004: Open Source Monorail (Closed) Base URL: https://chromium.googlesource.com/infra/infra.git@master
Patch Set: Rebase Created 4 years, 8 months ago
Index: appengine/monorail/tools/spam/spam.py
diff --git a/appengine/monorail/tools/spam/spam.py b/appengine/monorail/tools/spam/spam.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c7fd80d2108d770b569aa25a05da6ddabad28f
--- /dev/null
+++ b/appengine/monorail/tools/spam/spam.py
@@ -0,0 +1,303 @@
+#!/usr/bin/env python
+# Copyright 2016 The Chromium Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""
+Spam classifier command line tools.
+Use this command to work with Monorail's Cloud Prediction API
+Spam classifier models.
+
+This presumes you already have some CSV training data files present
+in GCS and/or on local disk, so run the training example exporter
+before trying to train or test models.
+
+Example: The following command will report the training status of the
+'android-user' model in the monorail-staging project:
+
+spam.py -p monorail-staging status -m android-user
+
+Note that in order for this command to work, you must have a service
+account credentials file on your machine, pointed to by the
+GOOGLE_APPLICATION_CREDENTIALS environment variable. Download one from
+Developer Console -> Credentials -> [service account] -> Generate new
+JSON key.
+"""
+
+import argparse
+import csv
+import hashlib
+import httplib2
+import json
+import logging
+import os
+import random
+import re
+import subprocess
+import sys
+import tempfile
+import time
+import googleapiclient.errors
+
+from googleapiclient.discovery import build
+from oauth2client.client import GoogleCredentials
+
+
+credentials = GoogleCredentials.get_application_default()
+service = build(
+    'prediction', 'v1.6', http=httplib2.Http(), credentials=credentials)
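+# Module-level Cloud Prediction API client using Application Default
+# Credentials; all of the subcommands below share it.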
+
+def Status(args):
+  result = service.trainedmodels().get(
+      project=args.project,
+      id=args.model,
+  ).execute()
+  return result
+
+def List(args):
+  result = service.trainedmodels().list(
+      project=args.project,
+  ).execute()
+  return result
+
+def Analyze(args):
+  result = service.trainedmodels().analyze(
+      project=args.project,
+      id=args.model,
+  ).execute()
+  return result
+
+def Train(args):
+  result = service.trainedmodels().insert(
+      project=args.project,
+      body={'id': args.model, 'storageDataLocation': args.training_data}
+  ).execute()
+  return result
+
+def _Classify(project, model, features):
+  retries = 0
+  while retries < 3:
+    try:
+      result = service.trainedmodels().predict(
+          project=project,
+          id=model,
+          body={'input': {'csvInstance': features}}
+      ).execute()
+      return result
+    except googleapiclient.errors.HttpError as err:
+      retries = retries + 1
+      print ('Error calling prediction API, attempt %d: %s' % (
+          retries, sys.exc_info()[0]))
+      print err.content.decode('utf-8')
+
+  # All three attempts failed; give up.
+  sys.exit(1)
+
+def Test(args):
+  with open(args.testing_data, 'rb') as csvfile:
+    spamreader = csv.reader(csvfile)
+    i = 0
+    confusion = {"ham": {"ham": 0, "spam": 0}, "spam": {"ham": 0, "spam": 0}}
+    for row in spamreader:
+      i = i + 1
+      if random.random() > args.sample_rate:
+        continue
+      label = row[0]
+      features = row[1:]
+      result = _Classify(args.project, args.model, features)
+      c = confusion[label][result['outputLabel']]
+      confusion[label][result['outputLabel']] = c + 1
+
+      print "%d: actual: %s / predicted: %s" % (i, label, result['outputLabel'])
+
+      if label != result['outputLabel']:
+        print "Mismatch:"
+        print json.dumps(row, indent=2)
+        print json.dumps(result, indent=2)
+
+  return confusion
+
+
+class struct(dict):
+  def __getattr__(self, key):
+    return self.get(key)
+  __setattr__ = dict.__setitem__
+  __delattr__ = dict.__delitem__
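+# struct makes dict entries readable and writable as attributes:
+#   s = struct(spam_score=0.9)
+#   s.spam_score == s['spam_score']  # True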
+
+
+def ROC(args):
+  # See page 866, Algorithm 1 in
+  # https://ccrma.stanford.edu/workshops/mir2009/references/ROCintro.pdf
+  # Modified to also keep track of the threshold for point labels
+  # when plotting the output.
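+  # In brief: classify a sample of the test rows, sort them by descending
+  # spam score, then sweep the threshold downward, emitting one (false
+  # positive rate, true positive rate) point per distinct score, plus a
+  # final endpoint.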
+
+  instances = []
+  with open(args.testing_data, 'rb') as csvfile:
+    spamreader = csv.reader(csvfile)
+    total_negative, total_positive = 0.0, 0.0
+    i = 0
+    for row in spamreader:
+      i = i + 1
+      if random.random() > args.sample_rate:
+        continue
+      label = row[0]
+      features = row[1:]
+      result = _Classify(args.project, args.model, features)
+      for p in result['outputMulti']:
+        if p['label'] == 'spam':
+          spam_score = float(p['score'])
+
+      if label == 'spam':
+        total_positive += 1
+      else:
+        total_negative += 1
+
+      instances.append(struct(true_class=label, spam_score=spam_score))
+
+  true_positive, false_positive = 0.0, 0.0
+  results = []
+
+  instances.sort(key=lambda inst: 1.0 - inst.spam_score)
+  score_prev = None
+
+  for inst in instances:
+    if score_prev is None or inst.spam_score != score_prev:
+      results.append(struct(
+          x=false_positive/total_negative,
+          y=true_positive/total_positive,
+          threshold=inst.spam_score))
+      score_prev = inst.spam_score
+
+    if inst.true_class == 'spam':
+      true_positive += 1
+    else:
+      false_positive += 1
+
+  results.append(struct(
+      x=false_positive/total_negative,
+      y=true_positive/total_positive,
+      threshold=inst.spam_score))
+
+  print "False Positive Rate, True Positive Rate, Threshold"
+  for r in results:
+    print "%f, %f, %f" % (r.x, r.y, r.threshold)
+
+  print "FP/N: %f/%f, TP/P: %f/%f" % (
+      false_positive, total_negative, true_positive, total_positive)
+
+def Prep(args):
+  with open(args.infile, 'rb') as csvfile:
+    with tempfile.NamedTemporaryFile('wb', delete=False) as trainfile:
+      with open(args.test, 'wb') as testfile:
+        for row in csvfile:
+          # If hash features are requested, generate those instead of
+          # the raw text.
+          if args.hash_features > 0:
+            row = row.split(',')
+            # Hash every field after the first (which is the class).
+            feature_hashes = _HashFeatures(row[1:], args.hash_features)
+            # Convert to strings so we can re-join the columns.
+            feature_hashes = [str(h) for h in feature_hashes]
+            row = [row[0]]
+            row.extend(feature_hashes)
+            row = ','.join(row) + '\n'
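+            # The row is now the label followed by bucket counts; e.g. a
+            # 'spam,<raw text>' line becomes something like 'spam,0,2,1,...'.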
+
+          if random.random() > args.ratio:
+            testfile.write(row)
+          else:
+            trainfile.write(row)
+
+  print 'Copying %s to The Cloud as %s' % (trainfile.name, args.train)
+  subprocess.check_call(['gsutil', 'cp', trainfile.name, args.train])
+
+DELIMITERS = [r'\s', r'\,', r'\.', r'\?', r'!', r'\:', r'\(', r'\)']
+
+def _HashFeatures(content, num_features):
+  """
+  Feature hashing is a fast and compact way to turn a string of text into a
+  vector of feature values for classification and training.
+  See also: https://en.wikipedia.org/wiki/Feature_hashing
+  This is a simple implementation that doesn't try to minimize collisions
+  or anything else fancy.
+  """
+  features = [0] * num_features
+  for blob in content:
+    words = re.split('|'.join(DELIMITERS), blob)
+    for w in words:
+      feature_index = int(int(hashlib.sha1(w).hexdigest(), 16) % num_features)
+      features[feature_index] += 1
+
+  return features
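+# For example, with num_features=4 every word in a row's text fields is
+# SHA-1 hashed and bucket (hash % 4) is incremented, collapsing text of any
+# length into a fixed 4-element vector of word-occurrence counts.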
+
+def main():
+  if 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
+    print ('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
+           'Exiting.')
+    sys.exit(1)
+
+  parser = argparse.ArgumentParser(
+      description='Spam classifier utilities.')
+  parser.add_argument('--project', '-p', default='monorail-staging')
+  subparsers = parser.add_subparsers(dest='command')
+
+  subparsers.add_parser('ls')
+
+  parser_analyze = subparsers.add_parser('analyze')
+  parser_analyze.add_argument('--model', '-m', required=True)
+
+  parser_status = subparsers.add_parser('status')
+  parser_status.add_argument('--model', '-m', required=True)
+
+  parser_test = subparsers.add_parser('test')
+  parser_test.add_argument('--model', '-m', required=True)
+  parser_test.add_argument('--testing_data', '-x',
+      help='Location of local testing csv file, e.g. /tmp/testing.csv')
+  parser_test.add_argument('--sample_rate', '-r', type=float, default=0.01,
+      help='Sample rate for classifier testing.')
+
+  parser_roc = subparsers.add_parser('roc',
+      help='Generate a Receiver Operating Characteristic curve')
+  parser_roc.add_argument('--model', '-m', required=True)
+  parser_roc.add_argument('--testing_data', '-x',
+      help='Location of local testing csv file, e.g. /tmp/testing.csv')
+  parser_roc.add_argument('--sample_rate', '-r', type=float, default=0.001,
+      help='Sample rate for classifier testing.')
+
+  parser_train = subparsers.add_parser('train')
+  parser_train.add_argument('--model', '-m', required=True)
+  parser_train.add_argument('--training_data', '-t',
+      help=('Location of training csv file (omit gs:// prefix), '
+            'e.g. monorail-staging-spam-training-data/train.csv'))
+
+  parser_prep = subparsers.add_parser('prep',
+      help='Split a csv file into training and test sets')
+  parser_prep.add_argument('--infile', '-i', required=True,
+      help='CSV file with complete set of labeled examples.')
+  parser_prep.add_argument('--train', required=True,
+      help=('Destination for training csv file, '
+            'e.g. gs://monorail-staging-spam-training-data/train.csv'))
+  parser_prep.add_argument('--test', required=True,
+      help='Destination for testing csv file, local filesystem.')
+  parser_prep.add_argument('--ratio', type=float, default=0.75,
+      help='Fraction of rows written to the training set; the rest go to '
+           'the test set.')
+  parser_prep.add_argument('--hash_features', '-f', type=int, default=0,
+      help='Number of hash features to generate.')
+
+  args = parser.parse_args()
+
+  cmds = {
+      'ls': List,
+      'analyze': Analyze,
+      'status': Status,
+      'test': Test,
+      'train': Train,
+      'prep': Prep,
+      'roc': ROC,
+  }
+  res = cmds[args.command](args)
+
+  print json.dumps(res, indent=2)
+
+
+if __name__ == '__main__':
+  main()