| OLD | NEW |
| (Empty) | |
| 1 # Copyright 2016 The Chromium Authors. All rights reserved. |
| 2 # Use of this source code is governed by a BSD-style license that can be |
| 3 # found in the LICENSE file. |
| 4 |
| 5 import unittest |
| 6 import random |
| 7 import numpy as np |
| 8 |
| 9 import crash.loglinear.test.model_test as loglinear_test |
| 10 from crash.loglinear.model import ToFeatureFunction |
| 11 from crash.loglinear.training import TrainableLogLinearModel |
| 12 |
| 13 |
| 14 training_data = [(x, x == 7) for x in xrange(10)] |
| 15 |
| 16 |
| 17 class TrainableLogLinearModelTest(unittest.TestCase): |
| 18 |
| 19 def setUp(self): |
| 20 super(TrainableLogLinearModelTest, self).setUp() |
| 21 feature_function = ToFeatureFunction(loglinear_test.features) |
| 22 initial_weights = [random.random() for _ in loglinear_test.features] |
| 23 self.model = TrainableLogLinearModel( |
| 24 loglinear_test.Y, training_data, feature_function, initial_weights) |
| 25 |
| 26 def testWeightsSetterNotAnNdarray(self): |
| 27 def _WeightSettingExpression(): |
| 28 """Wrap the ``self.model.weights = stuff`` expression. |
| 29 |
| 30 The ``assertRaises`` method expects a callable object, so we need |
| 31 to wrap the expression in a def. If we didn't wrap it in a def |
| 32 then we'd throw the exception too early, and ``assertRaises`` |
| 33 would never get called in order to see it. Normally we'd use a |
| 34 lambda for wrapping the expression up, but because the expression |
| 35 we want to check is actually a statement it can't be in a lambda |
| 36 but rather must be in a def. |
| 37 """ |
| 38 self.model.weights = 'this is not an np.ndarray' |
| 39 |
| 40 self.assertRaises(TypeError, _WeightSettingExpression) |
| 41 |
| 42 def testWeightsSetterShapeMismatch(self): |
| 43 def _WeightSettingExpression(): |
| 44 """Wrap the ``self.model.weights = stuff`` expression.""" |
| 45 # This np.ndarray has the wrong shape. |
| 46 self.model.weights = np.array([[1,2], [3,4]]) |
| 47 |
| 48 self.assertRaises(TypeError, _WeightSettingExpression) |
| 49 |
| 50 def testTrainWeights(self): |
| 51 """Tests that ``TrainWeights`` actually improves the loglikelihood. |
| 52 |
| 53 Actually, this is more of a test that we're calling SciPy's BFGS |
| 54 implementation correctly. But any bugs we find about that will show |
| 55 up in trying to run this rest rather than in the assertaion failing |
| 56 per se. |
| 57 """ |
| 58 initial_loglikelihood = self.model.LogLikelihood() |
| 59 self.model.TrainWeights(0.5) |
| 60 trained_loglikelihood = self.model.LogLikelihood() |
| 61 self.assertTrue(trained_loglikelihood >= initial_loglikelihood, |
| 62 'Training reduced the loglikelihood from %f to %f,' |
| 63 ' when it should have increased it!' |
| 64 % (initial_loglikelihood, trained_loglikelihood)) |
| OLD | NEW |