| Index: appengine/monorail/tools/spam/spam.py
|
| diff --git a/appengine/monorail/tools/spam/spam.py b/appengine/monorail/tools/spam/spam.py
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..20c7fd80d2108d770b569aa25a05da6ddabad28f
|
| --- /dev/null
|
| +++ b/appengine/monorail/tools/spam/spam.py
|
| @@ -0,0 +1,303 @@
|
| +#!/usr/bin/env python
|
| +# Copyright 2016 The Chromium Authors. All rights reserved.
|
| +# Use of this source code is governed by a BSD-style
|
| +# license that can be found in the LICENSE file or at
|
| +# https://developers.google.com/open-source/licenses/bsd
|
| +
|
| +"""
|
| +Spam classifier command-line tools.
|
| +Use this command to work with Monorail's Cloud Prediction API
|
| +spam classifier models.
|
| +
|
| +This presumes you already have CSV training data files present
|
| +in GCS and/or on local disk, so run the training example exporter
|
| +first before trying to train or test models.
|
| +
|
| +Example: The following command will report the training status of the
|
| +'android-user' model in the monorail-staging project:
|
| +
|
| +spam.py -p monorail-staging -m android-user status
|
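| +
|
| +The following command samples a local CSV file of labeled examples and
|
| +prints ROC curve data for the same model:
|
| +
|
| +spam.py -p monorail-staging -m android-user roc -x /tmp/testing.csv
|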
| +
|
| +Note that in order for this command to work, you must have a service
|
| +account credentials file on your machine. Download one from Developer
|
| +Console -> Credentials -> [service account] -> Generate new JSON key,
|
| +then point the GOOGLE_APPLICATION_CREDENTIALS environment variable at it.
|
| +"""
|
| +
|
| +import argparse
|
| +import csv
|
| +import hashlib
|
| +import json
|
| +import os
|
| +import random
|
| +import re
|
| +import subprocess
|
| +import sys
|
| +import tempfile
|
| +import googleapiclient.errors
|
| +
|
| +from googleapiclient.discovery import build
|
| +from oauth2client.client import GoogleCredentials
|
| +
|
| +
|
| +credentials = GoogleCredentials.get_application_default()
|
| +service = build('prediction', 'v1.6', credentials=credentials)
|
| +
|
| +def Status(args):
|
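| +  """Return the training status of a model."""
|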
| + result = service.trainedmodels().get(
|
| + project=args.project,
|
| + id=args.model,
|
| + ).execute()
|
| + return result
|
| +
|
| +def List(args):
|
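| +  """List all trained models in the project."""
|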
| + result = service.trainedmodels().list(
|
| + project=args.project,
|
| + ).execute()
|
| + return result
|
| +
|
| +def Analyze(args):
|
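| +  """Return the Prediction API's analysis of a trained model."""
|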
| + result = service.trainedmodels().analyze(
|
| + project=args.project,
|
| + id=args.model,
|
| + ).execute()
|
| + return result
|
| +
|
| +def Train(args):
|
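| +  """Start asynchronous training of a model from CSV data in GCS."""
|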
| + result = service.trainedmodels().insert(
|
| + project=args.project,
|
| + body={'id':args.model, 'storageDataLocation': args.training_data}
|
| + ).execute()
|
| + return result
|
| +
|
| +def _Classify(project, model, features):
|
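| +  """Predict the label for one example, retrying up to 3 times on HttpError."""
|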
| + retries = 0
|
| + while retries < 3:
|
| + try:
|
| + result = service.trainedmodels().predict(
|
| + project=project,
|
| + id=model,
|
| + body={'input': {'csvInstance': features}}
|
| + ).execute()
|
| + return result
|
| + except googleapiclient.errors.HttpError as err:
|
| + retries = retries + 1
|
| + print ('Error calling prediction API, attempt %d: %s' % (
|
| + retries, sys.exc_info()[0]))
|
| + print err.content.decode('utf-8')
|
| +
|
| +  sys.exit(1)
|
| +
|
| +def Test(args):
|
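| +  """Classify a sample of labeled examples and tally a confusion matrix."""
|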
| + with open(args.testing_data, 'rb') as csvfile:
|
| + spamreader = csv.reader(csvfile)
|
| + i = 0
|
| + confusion = {"ham": {"ham": 0, "spam": 0}, "spam": {"ham": 0, "spam": 0}}
|
| + for row in spamreader:
|
| + i = i + 1
|
| + if random.random() > args.sample_rate:
|
| + continue
|
| + label = row[0]
|
| + features = row[1:]
|
| + result = _Classify(args.project, args.model, features)
|
| + c = confusion[label][result['outputLabel']]
|
| + confusion[label][result['outputLabel']] = c + 1
|
| +
|
| + print "%d: actual: %s / predicted: %s" % (i, label, result['outputLabel'])
|
| +
|
| + if label != result['outputLabel']:
|
| + print "Mismatch:"
|
| + print json.dumps(row, indent=2)
|
| + print json.dumps(result, indent=2)
|
| +
|
| + return confusion
|
| +
|
| +
|
| +class struct(dict):
|
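| +  """A dict whose keys can also be read and written as attributes."""
|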
| + def __getattr__(self, key):
|
| + return self.get(key)
|
| + __setattr__ = dict.__setitem__
|
| + __delattr__ = dict.__delitem__
|
| +
|
| +
|
| +def ROC(args):
|
| + # See page 866, Algorithm 1 in
|
| + # https://ccrma.stanford.edu/workshops/mir2009/references/ROCintro.pdf
|
| + # Modified to also keep track of the threshold for point labels
|
| + # when plotting the output.
|
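| +  # Each distinct classifier score serves as a candidate threshold;
|
| +  # sweeping from the highest to the lowest score traces the curve
|
| +  # from (0, 0) toward (1, 1).
|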
| +
|
| + instances = []
|
| + with open(args.testing_data, 'rb') as csvfile:
|
| + spamreader = csv.reader(csvfile)
|
| + total_negative, total_positive = 0.0, 0.0
|
| + i = 0
|
| + for row in spamreader:
|
| + i = i + 1
|
| + if random.random() > args.sample_rate:
|
| + continue
|
| + label = row[0]
|
| + features = row[1:]
|
| + result = _Classify(args.project, args.model, features)
|
| + for p in result['outputMulti']:
|
| + if p['label'] == 'spam':
|
| + spam_score = float(p['score'])
|
| +
|
| + if label == 'spam':
|
| + total_positive += 1
|
| + else:
|
| + total_negative += 1
|
| +
|
| + instances.append(struct(true_class=label, spam_score=spam_score))
|
| +
|
| + true_positive, false_positive = 0.0, 0.0
|
| + results = []
|
| +
|
| +  instances.sort(key=lambda inst: inst.spam_score, reverse=True)
|
| +  score_prev = None
|
| +
|
| +  for inst in instances:
|
| +    if score_prev is None or inst.spam_score != score_prev:
|
| +      results.append(struct(
|
| +          x=false_positive/total_negative,
|
| +          y=true_positive/total_positive,
|
| +          threshold=inst.spam_score))
|
| +      score_prev = inst.spam_score
|
| +
|
| +    if inst.true_class == 'spam':
|
| +      true_positive += 1
|
| +    else:
|
| +      false_positive += 1
|
| +
|
| +  results.append(struct(
|
| +      x=false_positive/total_negative,
|
| +      y=true_positive/total_positive,
|
| +      threshold=inst.spam_score))
|
| +
|
| + print "False Positive Rate, True Positive Rate, Threshold"
|
| + for r in results:
|
| + print "%f, %f, %f" % (r.x, r.y, r.threshold)
|
| +
|
| + print "FP/N: %f/%f, TP/P: %f/%f" % (
|
| + false_positive, total_negative, true_positive, total_positive)
|
| +
|
| +def Prep(args):
|
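| +  """Split labeled examples into train and test files, optionally
|
| +  hashing text features, then copy the training file to GCS."""
|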
| + with open(args.infile, 'rb') as csvfile:
|
| + with tempfile.NamedTemporaryFile('wb', delete=False) as trainfile:
|
| + with open(args.test, 'wb') as testfile:
|
| + for row in csvfile:
|
| + # If hash features are requested, generate those instead of
|
| + # the raw text.
|
| + if args.hash_features > 0:
|
| + row = row.split(',')
|
| + # Hash every field after the first (which is the class)
|
| + feature_hashes = _HashFeatures(row[1:], args.hash_features)
|
| + # Convert to strings so we can re-join the columns.
|
| + feature_hashes = [str(h) for h in feature_hashes]
|
| + row = [row[0]]
|
| + row.extend(feature_hashes)
|
| + row = ','.join(row) + '\n'
|
| +
|
| + if random.random() > args.ratio:
|
| + testfile.write(row)
|
| + else:
|
| + trainfile.write(row)
|
| +
|
| + print 'Copying %s to The Cloud as %s' % (trainfile.name, args.train)
|
| + subprocess.check_call(['gsutil', 'cp', trainfile.name, args.train])
|
| +
|
| +DELIMITERS = [r'\s', r'\,', r'\.', r'\?', r'!', r'\:', r'\(', r'\)']
|
| +
|
| +def _HashFeatures(content, num_features):
|
| + """
|
| + Feature hashing is a fast and compact way to turn a string of text into a
|
| + vector of feature values for classification and training.
|
| + See also: https://en.wikipedia.org/wiki/Feature_hashing
|
| + This is a simple implementation that doesn't try to minimize collisions
|
| + or anything else fancy.
|
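| +
|
| +  For example, _HashFeatures(['spam spam eggs'], 4) splits the text into
|
| +  three words and returns a 4-element count vector whose entries sum to 3.
|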
| + """
|
| + features = [0] * num_features
|
| + for blob in content:
|
| + words = re.split('|'.join(DELIMITERS), blob)
|
| + for w in words:
|
| +      feature_index = int(hashlib.sha1(w).hexdigest(), 16) % num_features
|
| + features[feature_index] += 1
|
| +
|
| + return features
|
| +
|
| +def main():
|
| + if 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
|
| + print ('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
|
| + 'Exiting.')
|
| + sys.exit(1)
|
| +
|
| + parser = argparse.ArgumentParser(
|
| + description='Spam classifier utilities.')
|
| + parser.add_argument('--project', '-p', default='monorail-staging')
|
| + subparsers = parser.add_subparsers(dest='command')
|
| +
|
| + subparsers.add_parser('ls')
|
| +
|
| + parser_analyze = subparsers.add_parser('analyze')
|
| + parser_analyze.add_argument('--model', '-m', required=True)
|
| +
|
| + parser_status = subparsers.add_parser('status')
|
| + parser_status.add_argument('--model', '-m', required=True)
|
| +
|
| + parser_test = subparsers.add_parser('test')
|
| + parser_test.add_argument('--model', '-m', required=True)
|
| +  parser_test.add_argument('--testing_data', '-x', required=True,
|
| +    help='Location of local testing csv file, e.g. /tmp/testing.csv')
|
| +  parser_test.add_argument('--sample_rate', '-r', type=float, default=0.01,
|
| +    help='Sample rate for classifier testing.')
|
| +
|
| + parser_roc = subparsers.add_parser('roc',
|
| + help='Generate a Receiver Operating Characteristic curve')
|
| + parser_roc.add_argument('--model', '-m', required=True)
|
| +  parser_roc.add_argument('--testing_data', '-x', required=True,
|
| +    help='Location of local testing csv file, e.g. /tmp/testing.csv')
|
| +  parser_roc.add_argument('--sample_rate', '-r', type=float, default=0.001,
|
| +    help='Sample rate for classifier testing.')
|
| +
|
| + parser_train = subparsers.add_parser('train')
|
| + parser_train.add_argument('--model', '-m', required=True)
|
| +  parser_train.add_argument('--training_data', '-t', required=True,
|
| + help=('Location of training csv file (omit gs:// prefix), '
|
| + 'e.g. monorail-staging-spam-training-data/train.csv'))
|
| +
|
| + parser_prep = subparsers.add_parser('prep',
|
| + help='Split a csv file into training and test')
|
| + parser_prep.add_argument('--infile', '-i', required=True,
|
| +    help='CSV file with complete set of labeled examples.')
|
| + parser_prep.add_argument('--train', required=True,
|
| + help=('Destination for training csv file, '
|
| + 'e.g. gs://monorail-staging-spam-training-data/train.csv'))
|
| + parser_prep.add_argument('--test', required=True,
|
| +    help='Destination for testing csv file, local filesystem.')
|
| +  parser_prep.add_argument('--ratio', type=float, default=0.75,
|
| +    help='Fraction of examples to write to the training file.')
|
| + parser_prep.add_argument('--hash_features', '-f', type=int,
|
| + help='Number of hash features to generate.', default=0)
|
| +
|
| + args = parser.parse_args()
|
| +
|
| + cmds = {
|
| + "ls": List,
|
| + "analyze": Analyze,
|
| + "status": Status,
|
| + "test": Test,
|
| + "train": Train,
|
| + "prep": Prep,
|
| +    "roc": ROC,
|
| + }
|
| +  res = cmds[args.command](args)
|
| +
|
| +  if res is not None:
|
| +    print json.dumps(res, indent=2)
|
| +
|
| +
|
| +if __name__ == '__main__':
|
| + main()
|
|
|