Chromium Code Reviews

Side by Side Diff: appengine/monorail/tools/spam/spam.py

Issue 1868553004: Open Source Monorail (Closed) Base URL: https://chromium.googlesource.com/infra/infra.git@master
Patch Set: Rebase Created 4 years, 8 months ago
#!/usr/bin/env python
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd

"""
Spam classifier command line tools.
Use this command to work with Monorail's Cloud Prediction API
spam classifier models.

This presumes you already have some CSV training data files present
in GCS and/or on local disk, so run the training example exporter first
before trying to train or test models.

Example: the following command reports the training status of the
'android-user' model in the monorail-staging project:

  spam.py -p monorail-staging status -m android-user

Note that in order for this command to work, you must have a service
account credentials file on your machine. Download one from Developer
Console -> Credentials -> [service account] -> Generate new JSON key.
"""


import argparse
import csv
import hashlib
import httplib2
import json
import logging
import os
import random
import re
import subprocess
import sys
import tempfile
import time

import googleapiclient.errors
from googleapiclient.discovery import build
from oauth2client.client import GoogleCredentials

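# Build a Prediction API client using Application Default Credentials; the
# JSON key file is located via the GOOGLE_APPLICATION_CREDENTIALS environment
# variable, which main() checks below before doing anything else.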
credentials = GoogleCredentials.get_application_default()
service = build(
    'prediction', 'v1.6', http=httplib2.Http(), credentials=credentials)

def Status(args):
  result = service.trainedmodels().get(
      project=args.project,
      id=args.model,
  ).execute()
  return result

def List(args):
  result = service.trainedmodels().list(
      project=args.project,
  ).execute()
  return result

def Analyze(args):
  result = service.trainedmodels().analyze(
      project=args.project,
      id=args.model,
  ).execute()
  return result

def Train(args):
  result = service.trainedmodels().insert(
      project=args.project,
      body={'id': args.model, 'storageDataLocation': args.training_data}
  ).execute()
  return result

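# Classify a single example via the Prediction API's predict endpoint,
# retrying up to three times on HttpError before giving up.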
def _Classify(project, model, features):
  retries = 0
  while retries < 3:
    try:
      result = service.trainedmodels().predict(
          project=project,
          id=model,
          body={'input': {'csvInstance': features}}
      ).execute()
      return result
    except googleapiclient.errors.HttpError as err:
      retries = retries + 1
      print ('Error calling prediction API, attempt %d: %s' % (
          retries, sys.exc_info()[0]))
      print err.content.decode('utf-8')

  # All retries failed, so there is no result to return.
  sys.exit(1)

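# Evaluate the model against a local labeled CSV, sampling rows at
# --sample_rate and accumulating a 2x2 ham/spam confusion matrix.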
def Test(args):
  with open(args.testing_data, 'rb') as csvfile:
    spamreader = csv.reader(csvfile)
    i = 0
    confusion = {"ham": {"ham": 0, "spam": 0}, "spam": {"ham": 0, "spam": 0}}
    for row in spamreader:
      i = i + 1
      if random.random() > args.sample_rate:
        continue
      label = row[0]
      features = row[1:]
      result = _Classify(args.project, args.model, features)
      c = confusion[label][result['outputLabel']]
      confusion[label][result['outputLabel']] = c + 1

      print "%d: actual: %s / predicted: %s" % (i, label, result['outputLabel'])

      if label != result['outputLabel']:
        print "Mismatch:"
        print json.dumps(row, indent=2)
        print json.dumps(result, indent=2)

  return confusion

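# Minimal dict subclass allowing attribute-style access (s.foo == s['foo']);
# used below to hold per-instance and per-point ROC records.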
class struct(dict):
  def __getattr__(self, key):
    return self.get(key)
  __setattr__ = dict.__setitem__
  __delattr__ = dict.__delitem__


def ROC(args):
  # See page 866, Algorithm 1 in
  # https://ccrma.stanford.edu/workshops/mir2009/references/ROCintro.pdf
  # Modified to also keep track of the threshold for point labels
  # when plotting the output.
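  # The sweep works as follows: classify a sample of instances, sort them by
  # descending spam score, then walk down the sorted list treating each
  # distinct score as a candidate threshold and emitting one (FPR, TPR)
  # point per threshold.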

  instances = []
  with open(args.testing_data, 'rb') as csvfile:
    spamreader = csv.reader(csvfile)
    total_negative, total_positive = 0.0, 0.0
    i = 0
    for row in spamreader:
      i = i + 1
      if random.random() > args.sample_rate:
        continue
      label = row[0]
      features = row[1:]
      result = _Classify(args.project, args.model, features)
      for p in result['outputMulti']:
        if p['label'] == 'spam':
          spam_score = float(p['score'])

      if label == 'spam':
        total_positive += 1
      else:
        total_negative += 1

      instances.append(struct(true_class=label, spam_score=spam_score))

  true_positive, false_positive = 0.0, 0.0
  results = []

  instances.sort(key=lambda inst: 1.0 - inst.spam_score)
  score_prev = None

  for inst in instances:
    if score_prev is None or inst.spam_score != score_prev:
      results.append(struct(
          x=false_positive/total_negative,
          y=true_positive/total_positive,
          threshold=inst.spam_score))
      score_prev = inst.spam_score

    if inst.true_class == 'spam':
      true_positive += 1
    else:
      false_positive += 1

  results.append(struct(
      x=false_positive/total_negative,
      y=true_positive/total_positive,
      threshold=inst.spam_score))

  print "False Positive Rate, True Positive Rate, Threshold"
  for r in results:
    print "%f, %f, %f" % (r.x, r.y, r.threshold)

  print "FP/N: %f/%f, TP/P: %f/%f" % (
      false_positive, total_negative, true_positive, total_positive)

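# Split a labeled CSV into training and test sets, optionally replacing raw
# text fields with hashed feature counts, then copy the training set to GCS.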
def Prep(args):
  with open(args.infile, 'rb') as csvfile:
    with tempfile.NamedTemporaryFile('wb', delete=False) as trainfile:
      with open(args.test, 'wb') as testfile:
        for row in csvfile:
          # If hash features are requested, generate those instead of
          # the raw text.
          if args.hash_features > 0:
            row = row.split(',')
            # Hash every field after the first (which is the class).
            feature_hashes = _HashFeatures(row[1:], args.hash_features)
            # Convert to strings so we can re-join the columns.
            feature_hashes = [str(h) for h in feature_hashes]
            row = [row[0]]
            row.extend(feature_hashes)
            row = ','.join(row) + '\n'

          if random.random() > args.ratio:
            testfile.write(row)
          else:
            trainfile.write(row)

  print 'Copying %s to The Cloud as %s' % (trainfile.name, args.train)
  subprocess.check_call(['gsutil', 'cp', trainfile.name, args.train])

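# Regex alternatives used to split text into word tokens before hashing.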
DELIMITERS = ['\s', '\,', '\.', '\?', '!', '\:', '\(', '\)']

def _HashFeatures(content, num_features):
  """
  Feature hashing is a fast and compact way to turn a string of text into a
  vector of feature values for classification and training.
  See also: https://en.wikipedia.org/wiki/Feature_hashing
  This is a simple implementation that doesn't try to minimize collisions
  or anything else fancy.
  """
  features = [0] * num_features
  for blob in content:
    words = re.split('|'.join(DELIMITERS), blob)
    for w in words:
      feature_index = int(int(hashlib.sha1(w).hexdigest(), 16) % num_features)
      features[feature_index] += 1

  return features

def main():
  if 'GOOGLE_APPLICATION_CREDENTIALS' not in os.environ:
    print ('GOOGLE_APPLICATION_CREDENTIALS environment variable is not set. '
           'Exiting.')
    sys.exit(1)

  parser = argparse.ArgumentParser(
      description='Spam classifier utilities.')
  parser.add_argument('--project', '-p', default='monorail-staging')
  subparsers = parser.add_subparsers(dest='command')

  subparsers.add_parser('ls')

  parser_analyze = subparsers.add_parser('analyze')
  parser_analyze.add_argument('--model', '-m', required=True)

  parser_status = subparsers.add_parser('status')
  parser_status.add_argument('--model', '-m', required=True)

  parser_test = subparsers.add_parser('test')
  parser_test.add_argument('--model', '-m', required=True)
  parser_test.add_argument('--testing_data', '-x',
      help='Location of local testing csv file, e.g. /tmp/testing.csv')
  parser_test.add_argument('--sample_rate', '-r', type=float, default=0.01,
      help='Sample rate for classifier testing.')

  parser_roc = subparsers.add_parser('roc',
      help='Generate a Receiver Operating Characteristic curve')
  parser_roc.add_argument('--model', '-m', required=True)
  parser_roc.add_argument('--testing_data', '-x',
      help='Location of local testing csv file, e.g. /tmp/testing.csv')
  parser_roc.add_argument('--sample_rate', '-r', type=float, default=0.001,
      help='Sample rate for classifier testing.')

  parser_train = subparsers.add_parser('train')
  parser_train.add_argument('--model', '-m', required=True)
  parser_train.add_argument('--training_data', '-t',
      help=('Location of training csv file (omit gs:// prefix), '
            'e.g. monorail-staging-spam-training-data/train.csv'))

  parser_prep = subparsers.add_parser('prep',
      help='Split a csv file into training and test sets')
  parser_prep.add_argument('--infile', '-i', required=True,
      help='CSV file with complete set of labeled examples.')
  parser_prep.add_argument('--train', required=True,
      help=('Destination for training csv file, '
            'e.g. gs://monorail-staging-spam-training-data/train.csv'))
  parser_prep.add_argument('--test', required=True,
      help='Destination for testing csv file, local filesystem.')
  parser_prep.add_argument('--ratio', type=float, default=0.75,
      help='Fraction of examples sent to the training set.')
  parser_prep.add_argument('--hash_features', '-f', type=int, default=0,
      help='Number of hash features to generate.')

  args = parser.parse_args()

  cmds = {
      'ls': List,
      'analyze': Analyze,
      'status': Status,
      'test': Test,
      'train': Train,
      'prep': Prep,
      'roc': ROC,
  }
  res = cmds[args.command](args)

  print json.dumps(res, indent=2)


if __name__ == '__main__':
  main()