| OLD | NEW |
| (Empty) | |
| 1 # Copyright 2016 The Chromium Authors. All rights reserved. |
| 2 # Use of this source code is governed by a BSD-style license that can be |
| 3 # found in the LICENSE file. |
| 4 |
| 5 import numpy as np |
| 6 |
| 7 from crash.loglinear.model import ToFeatureFunction |
| 8 from crash.loglinear.training import TrainableLogLinearModel |
| 9 from crash.loglinear.test.loglinear_testcase import LoglinearTestCase |
| 10 |
| 11 |
class TrainableLogLinearModelTest(LoglinearTestCase):
  """Tests for ``TrainableLogLinearModel`` weight validation and training."""

  def setUp(self):
    super(TrainableLogLinearModelTest, self).setUp()
    # Normally we wouldn't have *all* possible training data. But this
    # is just a test; if it doesn't work now, it'll never work.
    training_data = [(x, x == 7) for x in self._X]
    self._model = TrainableLogLinearModel(
        self._Y, training_data, self._feature_function, self._weights)

  def testWeightsSetterNotAnNdarray(self):
    """Setting ``weights`` to a non-ndarray raises ``TypeError``.

    The context-manager form of ``assertRaises`` accepts statements
    directly, so no wrapper def is needed around the assignment.
    """
    with self.assertRaises(TypeError):
      self._model.weights = 'this is not an np.ndarray'

  def testWeightsSetterShapeMismatch(self):
    """Setting ``weights`` to an ndarray of the wrong shape raises."""
    with self.assertRaises(TypeError):
      # This np.ndarray has the wrong shape.
      self._model.weights = np.array([[1, 2], [3, 4]])

  def testTrainWeights(self):
    """Tests that ``TrainWeights`` actually improves the loglikelihood.

    Actually, this is more of a test that we're calling SciPy's BFGS
    implementation correctly. But any bugs we find about that will show
    up in trying to run this test rather than in the assertion failing
    per se.
    """
    initial_loglikelihood = self._model.LogLikelihood()
    self._model.TrainWeights(0.5)
    trained_loglikelihood = self._model.LogLikelihood()
    # assertGreaterEqual reports both operands on failure, unlike
    # assertTrue on a precomputed boolean.
    self.assertGreaterEqual(
        trained_loglikelihood, initial_loglikelihood,
        'Training reduced the loglikelihood from %f to %f,'
        ' when it should have increased it!'
        % (initial_loglikelihood, trained_loglikelihood))
| OLD | NEW |