OLD | NEW |
(Empty) | |
| 1 #!/usr/bin/env python |
| 2 # Copyright 2016 The Chromium Authors. All rights reserved. |
| 3 # Use of this source code is governed by a BSD-style |
| 4 # license that can be found in the LICENSE file or at |
| 5 # https://developers.google.com/open-source/licenses/bsd |
| 6 |
| 7 """ |
| 8 Spam classifier command line tools. |
| 9 Use this command to work with Monorail's Cloud Prediction API |
| 10 Spam classifier models. |
| 11 |
| 12 This presumes you already have some csv training data files present |
| 13 in gcs and/or local disk, so run the training example exporter first |
| 14 before trying to train or test models. |
| 15 |
| 16 Example: The following command will report the training status of the |
| 17 'android-user' model in the monorail-staging project: |
| 18 |
| 19 spam.py -p monorail-staging -m android-user status |
| 20 |
| 21 Note that in order for this command to work, you must have a service |
| 22 account credentials file on your machine. Download one from Developer |
| 23 Console -> Credentials -> [service account] -> Generate new JSON key. |
| 24 """ |
| 25 |
| 26 import argparse |
| 27 import csv |
| 28 import hashlib |
| 29 import httplib2 |
| 30 import json |
| 31 import logging |
| 32 import os |
| 33 import random |
| 34 import re |
| 35 import subprocess |
| 36 import sys |
| 37 import tempfile |
| 38 import time |
| 39 import googleapiclient |
| 40 |
| 41 from apiclient.discovery import build |
| 42 from oauth2client.client import GoogleCredentials |
| 43 |
| 44 |
# Module-level Prediction API client, built once at import time using the
# application-default service account credentials (see module docstring for
# how to obtain a credentials file).
credentials = GoogleCredentials.get_application_default()
service = build(
    'prediction', 'v1.6', http=httplib2.Http(), credentials=credentials)
| 48 |
def Status(args):
  """Return the trainedmodels.get result for the model (training status)."""
  request = service.trainedmodels().get(project=args.project, id=args.model)
  return request.execute()
| 55 |
def List(args):
  """Return the trainedmodels.list result for the project."""
  request = service.trainedmodels().list(project=args.project)
  return request.execute()
| 61 |
def Analyze(args):
  """Return the trainedmodels.analyze result for the model."""
  request = service.trainedmodels().analyze(project=args.project, id=args.model)
  return request.execute()
| 68 |
def Train(args):
  """Start training a model from the csv data at args.training_data in GCS."""
  model_spec = {'id': args.model, 'storageDataLocation': args.training_data}
  request = service.trainedmodels().insert(project=args.project, body=model_spec)
  return request.execute()
| 75 |
| 76 def _Classify(project, model, features): |
| 77 retries = 0 |
| 78 while retries < 3: |
| 79 try: |
| 80 result = service.trainedmodels().predict( |
| 81 project=project, |
| 82 id=model, |
| 83 body={'input': {'csvInstance': features}} |
| 84 ).execute() |
| 85 return result |
| 86 except googleapiclient.errors.HttpError as err: |
| 87 retries = retries + 1 |
| 88 print ('Error calling prediction API, attempt %d: %s' % ( |
| 89 retries, sys.exc_info()[0])) |
| 90 print err.content.decode('utf-8') |
| 91 |
| 92 sys.exit(1) |
| 93 |
| 94 return result |
| 95 |
| 96 def Test(args): |
| 97 with open(args.testing_data, 'rb') as csvfile: |
| 98 spamreader = csv.reader(csvfile) |
| 99 i = 0 |
| 100 confusion = {"ham": {"ham": 0, "spam": 0}, "spam": {"ham": 0, "spam": 0}} |
| 101 for row in spamreader: |
| 102 i = i + 1 |
| 103 if random.random() > args.sample_rate: |
| 104 continue |
| 105 label = row[0] |
| 106 features = row[1:] |
| 107 result = _Classify(args.project, args.model, features) |
| 108 c = confusion[label][result['outputLabel']] |
| 109 confusion[label][result['outputLabel']] = c + 1 |
| 110 |
| 111 print "%d: actual: %s / predicted: %s" % (i, label, result['outputLabel']) |
| 112 |
| 113 if label != result['outputLabel']: |
| 114 print "Mismatch:" |
| 115 print json.dumps(row, indent=2) |
| 116 print json.dumps(result, indent=2) |
| 117 |
| 118 return confusion |
| 119 |
| 120 |
class struct(dict):
  """Dict subclass with attribute-style access; missing keys read as None."""

  def __getattr__(self, name):
    # Only called when normal attribute lookup fails, so real dict methods
    # (get, items, ...) are unaffected.
    return dict.get(self, name)

  __setattr__ = dict.__setitem__
  __delattr__ = dict.__delitem__
| 126 |
| 127 |
def ROC(args):
  """Sample the testing csv, classify each row, and print ROC curve points.

  Output is csv-formatted text: false positive rate, true positive rate,
  and the spam-score threshold that produced each point.
  """
  # See page 866, Algorithm 1 in
  # https://ccrma.stanford.edu/workshops/mir2009/references/ROCintro.pdf
  # Modified to also keep track of the threshold for point labels
  # when plotting the output.

  instances = []
  with open(args.testing_data, 'rb') as csvfile:
    spamreader = csv.reader(csvfile)
    total_negative, total_positive = 0.0, 0.0
    i = 0
    for row in spamreader:
      i = i + 1
      # Down-sample rows to limit prediction API traffic.
      if random.random() > args.sample_rate:
        continue
      label = row[0]
      features = row[1:]
      result = _Classify(args.project, args.model, features)
      # NOTE(review): if 'spam' is absent from outputMulti, spam_score
      # silently keeps its value from the previous row — presumably the
      # API always returns both labels; verify.
      for p in result['outputMulti']:
        if p['label'] == 'spam':
          spam_score = float(p['score'])

      if label == 'spam':
        total_positive += 1
      else:
        total_negative += 1

      instances.append(struct(true_class=label, spam_score=spam_score))

  true_positive, false_positive = 0.0, 0.0
  results = []

  # Sort by descending spam score so the threshold sweeps strict -> lax.
  instances.sort(key=lambda i: 1.0 - i.spam_score)
  score_prev = None

  for i in instances:
    # Emit one curve point each time the threshold (score) changes.
    if score_prev is None or i.spam_score != score_prev:
      results.append(struct(
          x=false_positive/total_negative,
          y=true_positive/total_positive,
          threshold=i.spam_score))
      score_prev = i.spam_score

    if i.true_class == 'spam':
      true_positive += 1
    else:
      false_positive += 1

  # Final point after consuming every instance (upper-right of the curve).
  results.append(struct(
      x=false_positive/total_negative,
      y=true_positive/total_positive,
      threshold=i.spam_score))

  print "False Positive Rate, True Positive Rate, Threshold"
  for r in results:
    print "%f, %f, %f" % (r.x, r.y, r.threshold)

  print "FP/N: %f/%f, TP/P: %f/%f" % (
      false_positive, total_negative, true_positive, total_positive)
| 187 |
| 188 def Prep(args): |
| 189 with open(args.infile, 'rb') as csvfile: |
| 190 with tempfile.NamedTemporaryFile('wb', delete=False) as trainfile: |
| 191 with open(args.test, 'wb') as testfile: |
| 192 for row in csvfile: |
| 193 # If hash features are requested, generate those instead of |
| 194 # the raw text. |
| 195 if args.hash_features > 0: |
| 196 row = row.split(',') |
| 197 # Hash every field after the first (which is the class) |
| 198 feature_hashes = _HashFeatures(row[1:], args.hash_features) |
| 199 # Convert to strings so we can re-join the columns. |
| 200 feature_hashes = [str(h) for h in feature_hashes] |
| 201 row = [row[0]] |
| 202 row.extend(feature_hashes) |
| 203 row = ','.join(row) + '\n' |
| 204 |
| 205 if random.random() > args.ratio: |
| 206 testfile.write(row) |
| 207 else: |
| 208 trainfile.write(row) |
| 209 |
| 210 print 'Copying %s to The Cloud as %s' % (trainfile.name, args.train) |
| 211 subprocess.check_call(['gsutil', 'cp', trainfile.name, args.train]) |
| 212 |
| 213 DELIMITERS = ['\s', '\,', '\.', '\?', '!', '\:', '\(', '\)'] |
| 214 |
| 215 def _HashFeatures(content, num_features): |
| 216 """ |
| 217 Feature hashing is a fast and compact way to turn a string of text into a |
| 218 vector of feature values for classification and training. |
| 219 See also: https://en.wikipedia.org/wiki/Feature_hashing |
| 220 This is a simple implementation that doesn't try to minimize collisions |
| 221 or anything else fancy. |
| 222 """ |
| 223 features = [0] * num_features |
| 224 for blob in content: |
| 225 words = re.split('|'.join(DELIMITERS), blob) |
| 226 for w in words: |
| 227 feature_index = int(int(hashlib.sha1(w).hexdigest(), 16) % num_features) |
| 228 features[feature_index] += 1 |
| 229 |
| 230 return features |
| 231 |
| 232 def main(): |
| 233 if 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ: |
| 234 print ('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. ' |
| 235 'Exiting.') |
| 236 sys.exit(1) |
| 237 |
| 238 parser = argparse.ArgumentParser( |
| 239 description='Spam classifier utilities.') |
| 240 parser.add_argument('--project', '-p', default='monorail-staging') |
| 241 subparsers = parser.add_subparsers(dest='command') |
| 242 |
| 243 subparsers.add_parser('ls') |
| 244 |
| 245 parser_analyze = subparsers.add_parser('analyze') |
| 246 parser_analyze.add_argument('--model', '-m', required=True) |
| 247 |
| 248 parser_status = subparsers.add_parser('status') |
| 249 parser_status.add_argument('--model', '-m', required=True) |
| 250 |
| 251 parser_test = subparsers.add_parser('test') |
| 252 parser_test.add_argument('--model', '-m', required=True) |
| 253 parser_test.add_argument('--testing_data', '-x', |
| 254 help='Location of local testing csv file, e.g. /tmp/testing.csv') |
| 255 parser_test.add_argument('--sample_rate', '-r', default=0.01, |
| 256 help='Sample rate for classifier testing.') |
| 257 |
| 258 parser_roc = subparsers.add_parser('roc', |
| 259 help='Generate a Receiver Operating Characteristic curve') |
| 260 parser_roc.add_argument('--model', '-m', required=True) |
| 261 parser_roc.add_argument('--testing_data', '-x', |
| 262 help='Location of local testing csv file, e.g. /tmp/testing.csv') |
| 263 parser_roc.add_argument('--sample_rate', '-r', type=float, default=0.001, |
| 264 help='Sample rate for classifier testing.', ) |
| 265 |
| 266 parser_train = subparsers.add_parser('train') |
| 267 parser_train.add_argument('--model', '-m', required=True) |
| 268 parser_train.add_argument('--training_data', '-t', |
| 269 help=('Location of training csv file (omit gs:// prefix), ' |
| 270 'e.g. monorail-staging-spam-training-data/train.csv')) |
| 271 |
| 272 parser_prep = subparsers.add_parser('prep', |
| 273 help='Split a csv file into training and test') |
| 274 parser_prep.add_argument('--infile', '-i', required=True, |
| 275 help='CSV file with complete set of labeled examples.',) |
| 276 parser_prep.add_argument('--train', required=True, |
| 277 help=('Destination for training csv file, ' |
| 278 'e.g. gs://monorail-staging-spam-training-data/train.csv')) |
| 279 parser_prep.add_argument('--test', required=True, |
| 280 help='Destination for training csv file, local filesystem.') |
| 281 parser_prep.add_argument('--ratio', default=0.75, |
| 282 help='Test/train split ratio.') |
| 283 parser_prep.add_argument('--hash_features', '-f', type=int, |
| 284 help='Number of hash features to generate.', default=0) |
| 285 |
| 286 args = parser.parse_args() |
| 287 |
| 288 cmds = { |
| 289 "ls": List, |
| 290 "analyze": Analyze, |
| 291 "status": Status, |
| 292 "test": Test, |
| 293 "train": Train, |
| 294 "prep": Prep, |
| 295 'roc': ROC, |
| 296 } |
| 297 res = cmds[args.command](args) |
| 298 |
| 299 print json.dumps(res, indent=2) |
| 300 |
| 301 |
# Script entry point; main() exits the process on fatal errors.
if __name__ == '__main__':
  main()
OLD | NEW |