OLD | NEW |
| (Empty) |
1 # Copyright 2013 The Chromium Authors. All rights reserved. | |
2 # Use of this source code is governed by a BSD-style license that can be | |
3 # found in the LICENSE file. | |
4 | |
5 """Generates short, hopefully readable Python expressions that test variables | |
6 for certain combinations of values. | |
7 """ | |
8 | |
9 import itertools | |
10 import re | |
11 | |
12 | |
13 class ShortExpressionFinder(object): | |
14 """Usage: | |
15 >>> sef = ShortExpressionFinder([('foo', ("a", "b", "c", "d")), | |
16 ... ('bar', (1, 2))]) | |
17 >>> sef.get_expr([("a", 1)]) | |
18 foo=="a" and bar==1 | |
19 >>> sef.get_expr([("a", 1), ("b", 2), ("c", 1)]) | |
20 ((foo=="a" or foo=="c") and bar==1) or (foo=="b" and bar==2) | |
21 | |
22 The returned expressions are of the form | |
23 EXPR ::= EXPR2 ( "or" EXPR2 )* | |
24 EXPR2 ::= EXPR3 ( "and" EXPR3 )* | |
25 EXPR3 ::= VAR "==" VALUE ( "or" VAR "==" VALUE )* | |
26 where all of the comparisons in an EXPR2 involve the same variable. | |
27 Only positive tests are used so that all expressions will evaluate false if | |
28 given an unanticipated combination of values. | |
29 | |
30 A "cheapest" expression is returned. The cost of an expression is a function | |
31 of the number of var==value comparisons and the number of parentheses. | |
32 | |
33 The expression is found by exhaustive search, but it seems to be adequately | |
34 fast even for fairly large sets of configuration variables. | |
35 """ | |
36 | |
37 # These must be positive integers. | |
38 COMPARISON_COST = 1 | |
39 PAREN_COST = 1 | |
40 | |
41 def __init__(self, variables_and_values): | |
42 assert variables_and_values | |
43 for k, vs in variables_and_values: | |
44 assert re.match(r'\w+\Z', k) | |
45 assert vs | |
46 assert (all(isinstance(v, int) for v in vs) or | |
47 all(re.match(r'\w+\Z', v) for v in vs)) | |
48 | |
49 variables, values = zip(*((k, sorted(v)) for k, v in variables_and_values)) | |
50 valuecounts = map(len, values) | |
51 base_tests_by_cost = {} | |
52 | |
53 # Loop over nonempty subsets of values of each variable. This is about 2^n | |
54 # cases where n is the total number of values (currently 2+3=5 in Chrome). | |
55 for subsets in itertools.product(*(range(1, 1 << n) for n in valuecounts)): | |
56 # Supposing values == [['a', 'b', 'c'], [1, 2]], there are six | |
57 # configurations: ('a', 1), ('a', 2), ('b', 1), etc. Each gets a bit, in | |
58 # that order starting from the LSB. Start with the equivalent of | |
59 # set([('a', 1)]) and massage that into the correct set of configs. | |
60 bits = 1 | |
61 shift = 1 | |
62 cost = 0 | |
63 for subset, valuecount in zip(subsets, valuecounts): | |
64 oldbits, bits = bits, 0 | |
65 while subset: | |
66 if subset & 1: | |
67 bits |= oldbits | |
68 cost += self.COMPARISON_COST | |
69 oldbits <<= shift | |
70 subset >>= 1 | |
71 shift *= valuecount | |
72 # Charge an extra set of parens for the whole expression, | |
73 # which is removed later if appropriate. | |
74 cost += self.PAREN_COST * (1 + sum(bool(n & (n-1)) for n in subsets)) | |
75 base_tests_by_cost.setdefault(cost, {})[bits] = subsets | |
76 | |
77 self.variables = variables | |
78 self.values = values | |
79 self.base_tests_by_cost = base_tests_by_cost | |
80 | |
81 def get_expr(self, configs): | |
82 assert configs | |
83 for config in configs: | |
84 assert len(config) == len(self.values) | |
85 assert all(val in vals for val, vals in zip(config, self.values)) | |
86 return self._format_expr(self._get_expr_internal(configs)) | |
87 | |
88 def _get_expr_internal(self, configs): | |
89 bits = 0 | |
90 for config in configs: | |
91 bit = 1 | |
92 n = 1 | |
93 for value, values in zip(config, self.values): | |
94 bit <<= (n * values.index(value)) | |
95 n *= len(values) | |
96 bits |= bit | |
97 notbits = ~bits | |
98 | |
99 def try_partitions(parts, bits): | |
100 for cost, subparts in parts: | |
101 if cost is None: | |
102 return None if bits else () | |
103 try: | |
104 tests = self.base_tests_by_cost[cost] | |
105 except KeyError: | |
106 continue | |
107 for test in tests: | |
108 if (test & bits) and not (test & notbits): | |
109 result = try_partitions(subparts, bits & ~test) | |
110 if result is not None: | |
111 return (tests[test],) + result | |
112 return None | |
113 | |
114 for total_cost in itertools.count(0): | |
115 try: | |
116 return (self.base_tests_by_cost[total_cost + self.PAREN_COST][bits],) | |
117 except KeyError: | |
118 result = try_partitions(tuple(partitions(total_cost, self.PAREN_COST)), | |
119 bits) | |
120 if result is not None: | |
121 return result | |
122 | |
123 def _format_expr(self, expr): | |
124 out = [] | |
125 for expr2 in expr: | |
126 out2 = [] | |
127 for name, values, expr3 in zip(self.variables, self.values, expr2): | |
128 out3 = [] | |
129 for value in values: | |
130 if expr3 & 1: | |
131 if isinstance(value, basestring): | |
132 value = '"%s"' % value | |
133 out3.append('%s==%s' % (name, value)) | |
134 expr3 >>= 1 | |
135 out2.append(' or '.join(out3)) | |
136 if len(out3) > 1 and len(expr2) > 1: | |
137 out2[-1] = '(%s)' % out2[-1] | |
138 out.append(' and '.join(out2)) | |
139 if len(out2) > 1 and len(expr) > 1: | |
140 out[-1] = '(%s)' % out[-1] | |
141 return ' or '.join(out) | |
142 | |
143 | |
144 def partitions(n, minimum): | |
145 """Yields all the ways of expressing n as a sum of integers >= minimum, | |
146 in a slightly odd tree format. Most of the tree is left unevaluated. | |
147 Example: | |
148 partitions(4, 1) ==> | |
149 [1, <[1, <[1, <[1, <end>]>], | |
150 [2, <end>]>], | |
151 [3, <end>]>], | |
152 [2, <[2, <end>]>], | |
153 [4, <end>] | |
154 where <...> is a lazily-evaluated list and end == [None, None]. | |
155 """ | |
156 if n == 0: | |
157 yield (None, None) | |
158 for k in range(n, minimum - 1, -1): | |
159 children = partitions(n - k, k) | |
160 # We could just yield [k, children] here, but that would create a lot of | |
161 # blind alleys with no actual partitions. | |
162 try: | |
163 yield [k, MemoizedIterable(itertools.chain([next(children)], children))] | |
164 except StopIteration: | |
165 pass | |
166 | |
167 | |
168 class MemoizedIterable(object): | |
169 """Wrapper for an iterable that fully evaluates and caches the values the | |
170 first time it is iterated over. | |
171 """ | |
172 def __init__(self, iterable): | |
173 self.iterable = iterable | |
174 def __iter__(self): | |
175 self.iterable = tuple(self.iterable) | |
176 return iter(self.iterable) | |
OLD | NEW |