"""
Module providing utilities for parallelization of operations on 5C data.
The most important thing exposed in this module is the ``@parallelize_regions``
decorator, which automatically overloads any function to accept regional dicts
of any of its arguments and process the regions in parallel via the
``multiprocess`` package.
The other functions in this module are either example functions to show how it
works (``test_function_one``, etc.) or private helper functions.
"""
from functools import wraps
import inspect
import multiprocessing as mp
import dill
from lib5c.util.pretty_decorator import pretty_decorator
def _regional_param_depth(param, regions, depth=0):
depth += 1
if hasattr(param, 'keys') and param.keys():
if set(param.keys()) == set(regions):
return depth
else:
return _regional_param_depth(param[list(param.keys())[0]],
regions, depth)
return 0
def _regionalize_param(param, region, depth=1):
if depth == 0:
return param
elif depth == 1:
return param[region]
elif depth == 2:
return {upper_key: param[upper_key][region]
for upper_key in param.keys()}
else:
raise NotImplementedError('regions too deep in parameter')
def _regionalize_params(f, region, regions, args, kwargs):
return (
[_regionalize_param(arg, region, _regional_param_depth(arg, regions))
for arg in args],
{key: _regionalize_param(kwargs[key], region,
_regional_param_depth(kwargs[key], regions))
for key in kwargs})
def _unpack_for_map(payload):
fn, args, kwargs = dill.loads(payload)
return fn(args, kwargs)
def _pack_for_map(fn, args_kwargs_list):
return [dill.dumps((fn, args, kwargs)) for args, kwargs in args_kwargs_list]
@pretty_decorator
def parallelize_regions(f, suppress_warnings=True):
"""
A function decorator for parallelizing arbitrary functions to make them
operate on regional dicts in parallel.
Parameters
----------
f : Callable[[...], Any]
The function to parallelize.
suppress_warnings : bool
Make the wrapped functions suppress warnings, which prevents each of
them from writing their own errors when run by ``multiprocess``.
Returns
-------
Callable[[...], Any]
The parallelized version of ``f``.
Notes
-----
When the decorated function is called, the first positional argument is
checked to see if it defines a ``keys()`` function. If it doesn't, then the
original, non-parallel version of the function is executed. If it does, then
the keys of the first positional argument are taken to be the region names.
The other positional and keyword arguments are searched for any dict-like
structures with a matching set of keys, up to a depth of two nested dicts
deep, which are then identified as region-specific parameters. The function
is then executed in parallel for each region, using the appropriate
combination of region-specific and non-region-specific parameters. When all
the parallel executions return, their values are repackaged into a dict. If
the original non-parallel version of the function returns a tuple, the
return type of the parallelized invocation of the function will be a tuple
of dicts.
"""
def wrapped_function(args, kwargs):
if suppress_warnings:
import warnings
warnings.simplefilter('ignore')
return f(*args, **kwargs)
@wraps(f)
def parallel_func(*args, **kwargs):
# if the first arg is not a dict then give them the normal version
if len(args) == 0 or not hasattr(args[0], 'keys'):
return f(*args, **kwargs)
# condense and resolve all kwargs into args
all_arg_names, _, _, defaults = inspect.getargspec(f)
# repackage defaults into a dict for convenience
if defaults:
n_args = len(all_arg_names) - len(defaults)
defaults = {all_arg_names[n_args + i]: defaults[i]
for i in range(len(defaults))}
# this condition is false if args already represents args+kwargs
if len(args) != len(all_arg_names):
args = list(args)
for arg_name in all_arg_names:
if arg_name in kwargs:
args.append(kwargs[arg_name])
elif arg_name in defaults:
args.append(defaults[arg_name])
# guess regions based on keys of first arg
regions = list(args[0].keys())
# construct appropriate args list for parallel execution
args_kwargs_list = [
_regionalize_params(f, region, regions, args, kwargs)
for region in regions]
# process in parallel
try:
num_cpus = mp.cpu_count()
p = mp.Pool(num_cpus)
results = p.map(_unpack_for_map,
_pack_for_map(wrapped_function, args_kwargs_list))
p.close()
except Exception:
print('encountered exception, falling back to series operation')
results = list(map(wrapped_function, *list(zip(*args_kwargs_list))))
# repackage results
if type(results[0]) == tuple:
result = ({regions[i]: results[i][j]
for i in range(len(regions))}
for j in range(len(results[0])))
else:
result = {regions[i]: results[i] for i in range(len(regions))}
return result
return parallel_func
[docs]@parallelize_regions
def test_function_one(count):
return count * 2
[docs]@parallelize_regions
def test_function_two(count, multiplier=4):
return count * multiplier
[docs]@parallelize_regions
def test_function_three(x, y):
return x * y, x + y
[docs]@parallelize_regions
def test_function_four(x, y):
return x + y['s'] + y['t']
[docs]def main():
print(test_function_one(5))
print(test_function_one({'a': 4, 'b': 6}))
print(test_function_two(5))
print(test_function_two({'a': 4, 'b': 6}))
print(test_function_two({'a': 4, 'b': 6}, multiplier=3))
print(test_function_two({'a': 4, 'b': 6}, multiplier={'a': 5, 'b': 10}))
p, s = test_function_three(3, 4)
print(p, s)
p, s = test_function_three({'a': 4, 'b': 6}, 9)
print(p, s)
p, s = test_function_three({'a': 4, 'b': 6}, {'a': 4, 'b': 6})
print(p, s)
print(test_function_four(7, {'s': 1, 't': -1}))
print(test_function_four({'a': 6, 'b': 5},
{'s': {'a': 10, 'b': -10},
't': {'a': -100, 'b': 100}}))
if __name__ == '__main__':
main()