| OLD | NEW |
| 1 # Copyright 2016 The Chromium Authors. All rights reserved. | 1 # Copyright 2016 The Chromium Authors. All rights reserved. |
| 2 # Use of this source code is governed by a BSD-style license that can be | 2 # Use of this source code is governed by a BSD-style license that can be |
| 3 # found in the LICENSE file. | 3 # found in the LICENSE file. |
| 4 | 4 |
| 5 import numpy as np | 5 import numpy as np |
| 6 | 6 |
| 7 from crash.loglinear.model import ToFeatureFunction | |
| 8 from crash.loglinear.training import TrainableLogLinearModel | 7 from crash.loglinear.training import TrainableLogLinearModel |
| 9 from crash.loglinear.test.loglinear_testcase import LoglinearTestCase | 8 from crash.loglinear.test.loglinear_testcase import LoglinearTestCase |
| 10 | 9 |
| 11 | 10 |
| 12 class TrainableLogLinearModelTest(LoglinearTestCase): | 11 class TrainableLogLinearModelTest(LoglinearTestCase): |
| 13 | 12 |
| 14 def setUp(self): | 13 def setUp(self): |
| 15 super(TrainableLogLinearModelTest, self).setUp() | 14 super(TrainableLogLinearModelTest, self).setUp() |
| 16 # Normally we wouldn't have *all* possible training data. But this | 15 # Normally we wouldn't have *all* possible training data. But this |
| 17 # is just a test; if it doesn't work now, it'll never work. | 16 # is just a test; if it doesn't work now, it'll never work. |
| 18 training_data = [(x, x == 7) for x in self._X] | 17 training_data = [(x, x == 7) for x in self._X] |
| 19 self._model = TrainableLogLinearModel( | 18 self._model = TrainableLogLinearModel( |
| 20 self._Y, training_data, self._feature_function, self._weights) | 19 self._Y, training_data, self._feature_function, self._weights) |
| 21 | 20 |
| 22 def testWeightsSetterNotAnNdarray(self): | 21 def testNpWeightsSetterNotAnNdarray(self): |
| 23 def _WeightSettingExpression(): | 22 def _NpWeightSettingExpression(): |
| 24 """Wrap the ``self._model.weights = stuff`` expression. | 23 """Wrap the ``self._model.np_weights = stuff`` expression. |
| 25 | 24 |
| 26 The ``assertRaises`` method expects a callable object, so we need | 25 The ``assertRaises`` method expects a callable object, so we need |
| 27 to wrap the expression in a def. If we didn't wrap it in a def | 26 to wrap the expression in a def. If we didn't wrap it in a def |
| 28 then we'd throw the exception too early, and ``assertRaises`` | 27 then we'd throw the exception too early, and ``assertRaises`` |
| 29 would never get called in order to see it. Normally we'd use a | 28 would never get called in order to see it. Normally we'd use a |
| 30 lambda for wrapping the expression up, but because the expression | 29 lambda for wrapping the expression up, but because the expression |
| 31 we want to check is actually a statement it can't be in a lambda | 30 we want to check is actually a statement it can't be in a lambda |
| 32 but rather must be in a def. | 31 but rather must be in a def. |
| 33 """ | 32 """ |
| 34 self._model.weights = 'this is not an np.ndarray' | 33 self._model.np_weights = 'this is not an np.ndarray' |
| 34 |
| 35 self.assertRaises(TypeError, _NpWeightSettingExpression) |
| 36 |
| 37 def testNpWeightsSetterShapeMismatch(self): |
| 38 |
| 39 def _WeightSettingExpression(): |
| 40 """Wrap the ``self._model.weights = stuff`` expression.""" |
| 41 # This np.ndarray has the wrong shape. |
| 42 self._model.np_weights = np.array([[1,2], [3,4]]) |
| 35 | 43 |
| 36 self.assertRaises(TypeError, _WeightSettingExpression) | 44 self.assertRaises(TypeError, _WeightSettingExpression) |
| 37 | 45 |
| 38 def testWeightsSetterShapeMismatch(self): | |
| 39 def _WeightSettingExpression(): | |
| 40 """Wrap the ``self._model.weights = stuff`` expression.""" | |
| 41 # This np.ndarray has the wrong shape. | |
| 42 self._model.weights = np.array([[1,2], [3,4]]) | |
| 43 | |
| 44 self.assertRaises(TypeError, _WeightSettingExpression) | |
| 45 | |
| 46 def testTrainWeights(self): | 46 def testTrainWeights(self): |
| 47 """Tests that ``TrainWeights`` actually improves the loglikelihood. | 47 """Tests that ``TrainWeights`` actually improves the loglikelihood. |
| 48 | 48 |
| 49 Actually, this is more of a test that we're calling SciPy's BFGS | 49 Actually, this is more of a test that we're calling SciPy's BFGS |
| 50 implementation correctly. But any bugs we find about that will show | 50 implementation correctly. But any bugs we find about that will show |
| 51 up in trying to run this test rather than in the assertion failing | 51 up in trying to run this test rather than in the assertion failing |
| 52 per se. | 52 per se. |
| 53 """ | 53 """ |
| 54 initial_loglikelihood = self._model.LogLikelihood() | 54 initial_loglikelihood = self._model.LogLikelihood() |
| 55 self._model.TrainWeights(0.5) | 55 self._model.TrainWeights(0.5) |
| 56 trained_loglikelihood = self._model.LogLikelihood() | 56 trained_loglikelihood = self._model.LogLikelihood() |
| 57 self.assertTrue(trained_loglikelihood >= initial_loglikelihood, | 57 self.assertTrue(trained_loglikelihood >= initial_loglikelihood, |
| 58 'Training reduced the loglikelihood from %f to %f,' | 58 'Training reduced the loglikelihood from %f to %f,' |
| 59 ' when it should have increased it!' | 59 ' when it should have increased it!' |
| 60 % (initial_loglikelihood, trained_loglikelihood)) | 60 % (initial_loglikelihood, trained_loglikelihood)) |
| OLD | NEW |