OLD | NEW |
(Empty) | |
| 1 # Copyright 2017 The Chromium Authors. All rights reserved. |
| 2 # Use of this source code is governed by a BSD-style license that can be |
| 3 # found in the LICENSE file. |
| 4 |
| 5 """Helpers related to multiprocessing.""" |
| 6 |
| 7 import atexit |
| 8 import logging |
| 9 import multiprocessing |
| 10 import multiprocessing.dummy |
| 11 import os |
| 12 import sys |
| 13 import threading |
| 14 import traceback |
| 15 |
| 16 |
# When SUPERSIZE_DISABLE_ASYNC=1 is set in the environment, the helpers in this
# module run their work inline instead of using pools.
DISABLE_ASYNC = '1' == os.environ.get('SUPERSIZE_DISABLE_ASYNC')
if DISABLE_ASYNC:
  logging.debug('Running in synchronous mode.')

# Process pools created by _MakeProcessPool(); stays None until the first one.
_all_pools = None
# Set by _FuncWrapper; makes _TerminatePools() return without touching pools.
_is_child_process = False
# Once True, _CheckForException() no longer logs subprocess tracebacks.
_silence_exceptions = False
| 24 |
| 25 |
class _ImmediateResult(object):
  """Result object wrapping a value that has already been computed.

  Offers the get() / wait() / ready() / successful() methods so that callers
  can treat it like the result objects returned by pool.apply_async().
  """

  def __init__(self, value):
    self._value = value

  def ready(self):
    """Always True: the value was available at construction time."""
    return True

  def successful(self):
    """Always True: an instance only exists once the value was computed."""
    return True

  def wait(self):
    """No-op; there is nothing to wait for."""

  def get(self):
    """Returns the value passed to the constructor."""
    return self._value
| 41 |
| 42 |
class _ExceptionWrapper(object):
  """Carries an exception message from a child process to the main process.

  Attributes are documented inline; _FuncWrapper stores a formatted traceback
  string in |msg| and _CheckForException() logs it.
  """

  def __init__(self, msg):
    # Text describing the failure (a formatted traceback when created by
    # _FuncWrapper).
    self.msg = msg
| 47 |
| 48 |
class _FuncWrapper(object):
  """Callable handed to a pool to catch exceptions and spread *args.

  The wrapper is constructed by the caller of ForkAndCall() /
  BulkForkAndCall() (i.e. in the parent process) and then invoked by a pool
  worker. Only __call__() runs on the worker side, so that is where the
  process is marked as a child.
  """

  def __init__(self, func):
    # Do not touch _is_child_process here: this runs in the parent process,
    # and flagging the parent as a child would turn _TerminatePools() into a
    # no-op there.
    self._func = func

  def __call__(self, args, _=None):
    """Calls the wrapped function with |args| unpacked.

    Args:
      args: Sequence of positional arguments for the wrapped function.
      _: Ignored.

    Returns:
      The wrapped function's return value, or an _ExceptionWrapper holding the
      formatted traceback if it raised.
    """
    global _is_child_process
    _is_child_process = True
    try:
      return self._func(*args)
    except:  # pylint: disable=bare-except
      # multiprocessing is supposed to catch and return exceptions automatically
      # but it doesn't seem to work properly :(.
      logging.warning('CAUGHT EXCEPTION')
      return _ExceptionWrapper(traceback.format_exc())
| 64 |
| 65 |
class _WrappedResult(object):
  """Result proxy that runs host-side logic once the underlying result is done.

  * Removes the associated pool (if any) from _all_pools after waiting.
  * Passes the raw value through _CheckForException().
  * Optionally decodes the raw value with |decode_func|.
  """

  def __init__(self, result, pool=None, decode_func=None):
    self._result = result
    self._pool = pool
    self._decode_func = decode_func

  def ready(self):
    """Forwards to the underlying result's ready()."""
    return self._result.ready()

  def successful(self):
    """Forwards to the underlying result's successful()."""
    return self._result.successful()

  def wait(self):
    """Waits for the underlying result, then unregisters the pool once."""
    self._result.wait()
    if not self._pool:
      return
    _all_pools.remove(self._pool)
    # Clear the reference so that later wait() calls do not remove it again.
    self._pool = None

  def get(self):
    """Waits for and returns the (possibly decoded) value.

    The raw value is decoded only when a decode function was supplied and the
    underlying result reports success; otherwise it is returned unchanged.
    """
    self.wait()
    raw_value = self._result.get()
    _CheckForException(raw_value)
    if self._decode_func and self._result.successful():
      return self._decode_func(raw_value)
    return raw_value
| 97 |
| 98 |
def _TerminatePools():
  """Calls .terminate() on all active process pools.

  Not supposed to be necessary according to the docs, but seems to be required
  when child process throws an exception or Ctrl-C is hit.

  Also sets _silence_exceptions so that _CheckForException() stops logging.
  """
  global _silence_exceptions
  _silence_exceptions = True
  # Child processes cannot have pools, but atexit runs this function because
  # it was registered before fork()ing.
  if _is_child_process:
    return

  def _terminate_quietly(process_pool):
    # Best effort only: failures while terminating are deliberately ignored.
    try:
      process_pool.terminate()
    except:  # pylint: disable=bare-except
      pass

  for index, process_pool in enumerate(_all_pools):
    # Without calling terminate() on a separate thread, the call can block
    # forever.
    terminator = threading.Thread(
        name='Pool-Terminate-{}'.format(index),
        target=_terminate_quietly,
        args=(process_pool,))
    terminator.daemon = True
    terminator.start()
| 124 |
| 125 |
def _CheckForException(value):
  """Exits the process if |value| is an exception marshalled from a child.

  Values that are not _ExceptionWrapper instances are ignored. For a wrapped
  exception, the message is logged unless logging has been silenced (only the
  first one is logged), and then sys.exit(1) is called.
  """
  global _silence_exceptions
  if not isinstance(value, _ExceptionWrapper):
    return
  if not _silence_exceptions:
    _silence_exceptions = True
    logging.error('Subprocess raised an exception:\n%s', value.msg)
  sys.exit(1)
| 133 |
| 134 |
def _MakeProcessPool(*args):
  """Creates a multiprocessing.Pool and records it in _all_pools.

  The first call also initializes _all_pools and registers _TerminatePools()
  with atexit.

  Args:
    *args: Forwarded unchanged to multiprocessing.Pool().

  Returns:
    The newly created pool.
  """
  global _all_pools
  new_pool = multiprocessing.Pool(*args)
  is_first_pool = _all_pools is None
  if is_first_pool:
    _all_pools = []
    atexit.register(_TerminatePools)
  _all_pools.append(new_pool)
  return new_pool
| 143 |
| 144 |
def ForkAndCall(func, args, decode_func=None):
  """Runs |func| with |args| in a single-worker process pool.

  When DISABLE_ASYNC is set, |func| is instead called immediately in the
  current process.

  Args:
    func: Callable to run.
    args: Sequence of positional arguments for |func|.
    decode_func: Optional callable applied to the return value by the
        returned object's get().

  Returns:
    A Result object (call .get() to get the return value)
  """
  if not DISABLE_ASYNC:
    worker_pool = _MakeProcessPool(1)
    pending = worker_pool.apply_async(_FuncWrapper(func), (args,))
    # No further work is submitted to this pool.
    worker_pool.close()
    return _WrappedResult(pending, pool=worker_pool, decode_func=decode_func)

  immediate = _ImmediateResult(func(*args))
  return _WrappedResult(immediate, pool=None, decode_func=decode_func)
| 159 |
| 160 |
def BulkForkAndCall(func, arg_tuples):
  """Calls |func| in a pool worker for each set of args within |arg_tuples|.

  Yields the return values as they come in (not necessarily in input order).
  When DISABLE_ASYNC is set, |func| is called inline, in input order.

  Args:
    func: Callable to run.
    arg_tuples: Iterable of argument sequences; each one is unpacked into a
        call of |func|. May be empty, in which case nothing is yielded and no
        pool is created.
  """
  # Materialize so that iterators are accepted and the size is known.
  arg_tuples = list(arg_tuples)
  # multiprocessing.Pool() rejects a size of 0, so bail out before sizing it.
  if not arg_tuples:
    return
  if DISABLE_ASYNC:
    for args in arg_tuples:
      yield func(*args)
    return
  pool_size = min(len(arg_tuples), multiprocessing.cpu_count())
  pool = _MakeProcessPool(pool_size)
  wrapped_func = _FuncWrapper(func)
  for result in pool.imap_unordered(wrapped_func, arg_tuples):
    # Exits the process if the worker reported an exception.
    _CheckForException(result)
    yield result
  pool.close()
  pool.join()
  _all_pools.remove(pool)
| 179 |
| 180 |
def CallOnThread(func, *args, **kwargs):
  """Calls |func| on a new thread and returns a promise for its return value.

  When DISABLE_ASYNC is set, |func| is called immediately and the returned
  object already holds its value.

  Args:
    func: Callable to run.
    *args: Positional arguments for |func|.
    **kwargs: Keyword arguments for |func|.

  Returns:
    An object whose .get() returns the value returned by |func|.
  """
  if not DISABLE_ASYNC:
    thread_pool = multiprocessing.dummy.Pool(1)
    promise = thread_pool.apply_async(func, args=args, kwds=kwargs)
    # No further work is submitted to this pool.
    thread_pool.close()
    return promise
  return _ImmediateResult(func(*args, **kwargs))
| 189 |
| 190 |
def EncodeDictOfLists(d, key_transform=None):
  """Serializes a dict where values are lists of strings.

  Keys are joined with '\\x01'. Each value list is joined with '\\x02' and the
  resulting strings are joined with '\\x01', in the dict's iteration order.
  Keys and list items therefore must not contain those characters.

  Args:
    d: Dict mapping keys to lists of strings.
    key_transform: Optional callable applied to each key before joining; it
        must return a string.

  Returns:
    A tuple of (encoded_keys, encoded_values) strings.
  """
  keys = iter(d)
  if key_transform:
    keys = (key_transform(k) for k in keys)
  encoded_keys = '\x01'.join(keys)
  # values() rather than itervalues(): the latter exists only on Python 2.
  encoded_values = '\x01'.join('\x02'.join(x) for x in d.values())
  return encoded_keys, encoded_values
| 199 |
| 200 |
def DecodeDictOfLists(encoded_keys, encoded_values, key_transform=None):
  """Deserializes a dict where values are lists of strings.

  Expects the '\\x01' / '\\x02' separated format: keys separated by '\\x01',
  value lists separated by '\\x01', and items within a list by '\\x02'.

  The encoding cannot distinguish an empty list from [''], nor an empty dict
  from {'': []}; both are decoded as the empty form.

  Args:
    encoded_keys: String of keys.
    encoded_values: String of value lists, in the same order as the keys.
    key_transform: Optional callable applied to each decoded key.

  Returns:
    A dict mapping each (transformed) key to its list of strings.

  Raises:
    IndexError: If there are fewer encoded value lists than keys.
  """
  # ''.split() yields [''], which would otherwise produce a bogus '' entry.
  if not encoded_keys and not encoded_values:
    return {}
  keys = encoded_keys.split('\x01')
  if key_transform:
    keys = (key_transform(k) for k in keys)
  encoded_lists = encoded_values.split('\x01')
  ret = {}
  for i, key in enumerate(keys):
    encoded_list = encoded_lists[i]
    # An empty string means an empty list, not a list holding ''.
    ret[key] = encoded_list.split('\x02') if encoded_list else []
  return ret
OLD | NEW |