OLD | NEW |
| (Empty) |
1 # Copyright 2016 The Chromium Authors. All rights reserved. | |
2 # Use of this source code is governed by a BSD-style license that can be | |
3 # found in the LICENSE file. | |
4 | |
5 """Pure Python implementation of the Mann-Whitney U test. | |
6 | |
7 This code is adapted from SciPy: | |
8 https://github.com/scipy/scipy/blob/master/scipy/stats/stats.py | |
9 Which is provided under a BSD-style license. | |
10 | |
11 There is also a JavaScript version in Catapult: | |
12 https://github.com/catapult-project/catapult/blob/master/tracing/third_party/mannwhitneyu/mannwhitneyu.js | |
13 """ | |
14 | |
15 import itertools | |
16 import math | |
17 | |
18 | |
def MannWhitneyU(x, y):
  """Computes the Mann-Whitney rank test on samples x and y.

  The distribution of U is approximately normal for large samples. This
  implementation uses the normal approximation, so it's recommended to have
  sample sizes > 20.

  Args:
    x: A sequence of numbers (the first sample).
    y: A sequence of numbers (the second sample).

  Returns:
    The two-sided p-value, computed as twice the standard normal survival
    function evaluated at |z|.

  Raises:
    ValueError: If either sample is empty, or if every value in the combined
        samples is identical.
  """
  n1 = len(x)
  n2 = len(y)
  if n1 == 0 or n2 == 0:
    # With an empty sample the standard deviation below is zero, which would
    # otherwise surface as an unexplained ZeroDivisionError.
    raise ValueError('Both samples must be non-empty in mannwhitneyu')

  # Copy into lists so that any pair of sequence types can be concatenated.
  ranked = _RankData(list(x) + list(y))
  rankx = ranked[0:n1]  # The ranks belonging to the x sample.
  u1 = n1*n2 + n1*(n1+1)/2.0 - sum(rankx)  # U statistic for x.
  u2 = n1*n2 - u1  # The remainder is the U statistic for y.

  t = _TieCorrectionFactor(ranked)
  if t == 0:
    raise ValueError('All numbers are identical in mannwhitneyu')
  sd = math.sqrt(t * n1 * n2 * (n1+n2+1) / 12.0)

  # The extra 0.5 is a continuity correction for the normal approximation.
  mean_rank = n1*n2/2.0 + 0.5
  big_u = max(u1, u2)

  z = (big_u - mean_rank) / sd
  return 2 * _NormSf(abs(z))
42 | |
43 | |
44 def _RankData(a): | |
45 """Assigns ranks to data. Ties are given the mean of the ranks of the items. | |
46 | |
47 This is called "fractional ranking": | |
48 https://en.wikipedia.org/wiki/Ranking | |
49 """ | |
50 sorter = _ArgSortReverse(a) | |
51 ranked_min = [0] * len(sorter) | |
52 for i, j in reversed(list(enumerate(sorter))): | |
53 ranked_min[j] = i | |
54 | |
55 sorter = _ArgSort(a) | |
56 ranked_max = [0] * len(sorter) | |
57 for i, j in enumerate(sorter): | |
58 ranked_max[j] = i | |
59 | |
60 return [1 + (x+y)/2.0 for x, y in zip(ranked_min, ranked_max)] | |
61 | |
62 | |
63 def _ArgSort(a): | |
64 """Returns the indices that would sort an array. | |
65 | |
66 Ties are given indices in ordinal order.""" | |
67 return sorted(range(len(a)), key=a.__getitem__) | |
68 | |
69 | |
70 def _ArgSortReverse(a): | |
71 """Returns the indices that would sort an array. | |
72 | |
73 Ties are given indices in reverse ordinal order.""" | |
74 return list(reversed(sorted(range(len(a)), key=a.__getitem__, reverse=True))) | |
75 | |
76 | |
77 def _TieCorrectionFactor(rankvals): | |
78 """Tie correction factor for ties in the Mann-Whitney U test.""" | |
79 arr = sorted(rankvals) | |
80 cnt = [len(list(group)) for _, group in itertools.groupby(arr)] | |
81 size = len(arr) | |
82 if size < 2: | |
83 return 1.0 | |
84 else: | |
85 return 1.0 - sum(x**3 - x for x in cnt) / float(size**3 - size) | |
86 | |
87 | |
88 def _NormSf(x): | |
89 """Survival function of the standard normal distribution. (1 - cdf)""" | |
90 return (1 - math.erf(x/math.sqrt(2))) / 2 | |
OLD | NEW |