"""
Module containing utilities related to plotting.
"""
import inspect
from functools import wraps
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lib5c.util.pretty_decorator import pretty_decorator
from lib5c.util.system import check_outdir
from lib5c.util.dictionaries import reduced_get
# rc overrides handed to ``sns.axes_style()`` by ``plotter``: render text,
# tick marks, axis labels and axis edges in black regardless of the style.
DEFAULT_RC = {rc_key: 'black' for rc_key in ('text.color',
                                             'xtick.color',
                                             'ytick.color',
                                             'axes.labelcolor',
                                             'axes.edgecolor')}


def _get_arg_names_and_defaults(f):
    """
    Returns the named positional parameters of ``f`` and their defaults.

    Prefers ``inspect.getfullargspec()`` because ``inspect.getargspec()`` was
    removed in Python 3.11 (and raises on annotations / keyword-only args);
    falls back to ``inspect.getargspec()`` on interpreters lacking the former.

    Parameters
    ----------
    f : function
        The function to inspect.

    Returns
    -------
    list of str, tuple
        The names of ``f``'s positional parameters, and the default values of
        the trailing parameters that have them (empty tuple if none do).
    """
    get_spec = getattr(inspect, 'getfullargspec', None)
    if get_spec is None:
        get_spec = inspect.getargspec
    spec = get_spec(f)
    # index 0 is the arg names and index 3 the defaults in both ArgSpec and
    # FullArgSpec
    return list(spec[0]), tuple(spec[3] or ())


@pretty_decorator
def plotter(func):
    """
    Multi-purpose decorator for plotting functions.

    Decorated functions should accept ``**kwargs`` in their signature. Clients
    can then pass a wide variety of kwargs to the decorated function. This
    includes all kwargs of ``adjust_plot()`` (applied after plotting, whether
    or not the plot is saved) as well as ``ax`` (the axis to draw on),
    ``outfile`` (saves the plot to disk and closes the figure), ``dpi`` (sets
    the DPI for the saved plot, default 100), and ``style`` (sets the seaborn
    style for the plot, default ``'ticks'``; pass None to skip seaborn
    styling).

    Kwargs consumed by the decorator are not forwarded to the wrapped
    function: it is called with its non-defaulted positional args plus its own
    defaulted parameters (declared defaults overridden by passed values).

    The decorated function returns the axis that was drawn on. If the wrapped
    function returns a non-None value it is appended to the result (a plain
    tuple return is flattened into the result tuple).
    """
    @wraps(func)
    def decorated_func(*args, **kwargs):
        # inspect func
        names, defaults = _get_arg_names_and_defaults(func)
        n_defaults = len(defaults)
        defaulted_names = names[len(names)-n_defaults:]
        defaults_dict = dict(zip(defaulted_names, defaults))
        # NOTE(review): this assumes every named parameter of ``func`` arrives
        # in ``args`` (positionally, with defaults already filled in), so the
        # trailing ``n_defaults`` entries line up with ``defaulted_names``;
        # presumably ``pretty_decorator`` guarantees this -- confirm.
        kwargs_dict = dict(zip(defaulted_names, args[len(args)-n_defaults:]))
        positional_args = args[:len(args)-n_defaults]

        # extract params not used by adjust_plot() from **kwargs
        dicts_to_search = [defaults_dict, kwargs_dict, kwargs]
        ax = reduced_get('ax', dicts_to_search)
        outfile = reduced_get('outfile', dicts_to_search)
        dpi = reduced_get('dpi', dicts_to_search, 100)
        style = reduced_get('style', dicts_to_search, 'ticks')

        # construct plot_kwargs, honoring defaults defined by func; later
        # sources override earlier ones
        adj_names, adj_defaults = _get_arg_names_and_defaults(adjust_plot)
        plot_kwargs = dict(zip(adj_names[len(adj_names)-len(adj_defaults):],
                               adj_defaults))
        for source in (defaults_dict, kwargs_dict, kwargs):
            plot_kwargs.update({key: source[key] for key in adj_names
                                if key in source})

        # special case: if the plotter function has a legend=True kwarg, it
        # draws its own legend, so don't redraw it
        if plot_kwargs.get('legend') is True and 'legend' in defaults_dict:
            plot_kwargs['legend'] = None

        # create other_kwargs as "everything else": func's own defaulted
        # params, with any values actually passed taking precedence
        other_kwargs = dict(defaults_dict)
        other_kwargs.update({key: kwargs_dict[key] for key in defaults_dict
                             if key in kwargs_dict})

        # honor ax
        if ax is not None:
            plt.sca(ax)

        # clear figure if plotting in `outfile` mode
        if outfile is not None:
            plt.clf()

        try:
            # do the actual plotting
            if style is not None:
                # prepare sns
                sns.set(color_codes=True)
                with sns.axes_style(style, DEFAULT_RC):
                    retval = func(*positional_args, **other_kwargs)
            else:
                retval = func(*positional_args, **other_kwargs)

            # apply the adjust_plot() kwargs in every mode, not only when
            # saving, so e.g. xlim/xlabel are honored for interactive plots
            adjust_plot(**plot_kwargs)

            # grab the axis before plt.close(): calling plt.gca() afterwards
            # would create (and leak) a brand-new empty figure
            plotted_ax = plt.gca()

            # save figure
            if outfile is not None:
                check_outdir(outfile)
                plt.savefig(outfile, dpi=dpi, bbox_inches='tight')
                plt.close()
        finally:
            # reset seaborn even if plotting raised
            sns.reset_orig()

        # return something sensible; only exact tuples are flattened (tuple
        # subclasses such as namedtuples are returned whole)
        if retval is None:
            return plotted_ax
        if type(retval) == tuple:
            return (plotted_ax,) + retval
        return plotted_ax, retval
    return decorated_func


def adjust_plot(ax=None, xlim=None, ylim=None, xlabel=None, ylabel=None,
                xticks=None, yticks=None, title=None, despine=True,
                legend=None):
    """
    Multipurpose plot adjustment method.

    Parameters
    ----------
    ax : pyplot axis
        The axis to operate on. If None, the current axis is used; otherwise
        it is made the current axis.
    xlim : tuple of numeric
        Pass a tuple of the form (min, max) to set the x-limits of the plot.
    ylim : tuple of numeric
        Pass a tuple of the form (min, max) to set the y-limits of the plot.
    xlabel : str
        Label for the x-axis.
    ylabel : str
        Label for the y-axis.
    xticks : int, list of int, or tuple of list of ints
        Pass an int (Python or numpy integer) to use this as the spacing for
        the xticks, starting at the current lower x-limit. Pass a
        ``(positions, labels)`` tuple to call ``plt.xticks(positions,labels)``.
        Pass anything else to pass it directly to ``plt.xticks()``.
    yticks : int, list of int, or tuple of list of ints
        Pass an int (Python or numpy integer) to use this as the spacing for
        the yticks, starting at the current lower y-limit. Pass a
        ``(positions, labels)`` tuple to call ``plt.yticks(positions,labels)``.
        Pass anything else to pass it directly to ``plt.yticks()``.
    title : str
        Title for the plot.
    despine : bool
        Pass True to despine the plot (``sns.despine()`` on the current
        figure).
    legend : str or bool or None
        Pass False to remove the legend, pass True to add a default legend, pass
        'outside' to move the legend outside the plot area, pass None to leave
        the legend alone.
    """
    if ax is None:
        ax = plt.gca()
    else:
        plt.sca(ax)

    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)

    # ticks are applied after the limits so integer spacings use the new ones
    _apply_ticks(xticks, plt.xlim, plt.xticks)
    _apply_ticks(yticks, plt.ylim, plt.yticks)

    if title is not None:
        plt.title(title)
    if despine:
        sns.despine()

    if legend is not None:
        if legend == 'outside':
            # anchor the legend's upper-left corner just right of the axes
            plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
        elif legend is False:
            if ax.legend_:
                ax.legend_.remove()
        elif legend is True:
            plt.legend()


def _apply_ticks(ticks, get_lim, set_ticks):
    """
    Applies one axis' ``xticks``/``yticks`` argument as described in
    ``adjust_plot()``.

    Parameters
    ----------
    ticks : int, tuple, other, or None
        None does nothing; a tuple is unpacked into ``set_ticks``; an integer
        (but not a bool) is a tick spacing starting at the lower limit;
        anything else is passed to ``set_ticks`` unchanged.
    get_lim : callable
        ``plt.xlim`` or ``plt.ylim``; called with no arguments to read the
        current ``(start, stop)`` limits.
    set_ticks : callable
        ``plt.xticks`` or ``plt.yticks``.
    """
    if ticks is None:
        return
    if isinstance(ticks, tuple):
        set_ticks(*ticks)
        return
    # accept numpy integers too; bool is an int subclass but is not a spacing
    if isinstance(ticks, (int, np.integer)) and not isinstance(ticks, bool):
        start, stop = get_lim()
        # ``stop + ticks`` so the upper limit itself is included when it falls
        # on the grid
        ticks = np.arange(start, stop + ticks, ticks)
    set_ticks(ticks)


def compute_hexbin_extent(xlim, ylim, logx=False, logy=False):
    """
    Helper function for computing the ``extent`` kwarg of ``plt.hexbin()``.

    Parameters
    ----------
    xlim, ylim : tuple, optional
        Tuple of `(x_min, x_max)` and `(y_min, y_max)`, respectively. If either
        is None, no attempt will be made to set the extent.
    logx, logy: bool
        Whether or not ``plt.hexbin()`` is being called with ``xscale='log'``
        and/or ``yscale='log'``, respectively. On a log-scaled axis the limits
        are converted to base-10 exponents; non-positive limits (which have no
        logarithm) are mapped to 0.

    Returns
    -------
    list or None
        The extent ``[x_min, x_max, y_min, y_max]`` if it could be computed or
        None otherwise.
    """
    if xlim is None or ylim is None:
        return None

    def _to_extent(lim, log):
        # returns a new list so the caller's limits are never aliased
        if not log:
            return list(lim)
        return [np.log10(value) if value > 0 else 0 for value in lim]

    return _to_extent(xlim, logx) + _to_extent(ylim, logy)