"""
Module for plotting mean-variance relationships.
"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from lib5c.util.counts import flatten_obs_and_exp, flatten_obs_and_exp_counts
from lib5c.util.grouping import group_obs_by_exp
from lib5c.util.plotting import plotter
from lib5c.util.parallelization import parallelize_regions
from lib5c.plotters.scatter import scatter as plot_scatter
@plotter
def plot_mvr(exp, var, obs=None, num_groups=100, group_fractional_tolerance=0.1,
             exclude_offdiagonals=5, log=False, logx=False, logy=False,
             vst=False, scatter=False, hexbin=False, trim_limits=False,
             xlim=None, ylim=None, **kwargs):
    """
    Plots a scatterplot of exp vs var.

    Optionally, pass obs to instead scatterplot exp vs freshly re-estimated
    group variances, and overlay exp vs var as a smooth curve.

    Parameters
    ----------
    exp : np.ndarray or dict of np.ndarray
        Regional matrix of expected values. Pass a counts dict to combine all
        regions together.
    var : np.ndarray or dict of np.ndarray
        Regional matrix of variances. Pass a counts dict to combine all regions
        together.
    obs : np.ndarray or dict of np.ndarray, optional
        Regional matrix of observed values. Pass a counts dict to combine all
        regions together.
    num_groups : int
        The number of groups to re-estimate group variances for.
    group_fractional_tolerance : float
        The width of each group, specified as a fractional tolerance in the
        expected value.
    exclude_offdiagonals : int
        Exclude this many off-diagonals from the variance re-estimation. Pass 0
        to exclude only the exact diagonal. Pass -1 to exclude nothing.
    log : bool
        Pass True to log both exp and var.
    logx, logy : bool
        Pass True to draw the x- and/or y-axis on a log-scale.
    vst : bool
        Pass True to log only the exp (e.g., when var is already stabilized).
    scatter : bool
        Pass True to force plotting exp vs var as a scatterplot when obs is not
        passed. By default it will be a line plot.
    hexbin : bool
        Pass True when ``scatter=True`` to replace the scatterplot with a hexbin
        plot. Ignored otherwise.
    trim_limits : bool
        If `obs` is passed, pass True to trim the x- and y-limits to the range
        of the group expected and variance values. NaN values (e.g. the
        variance of a group with no finite observations) are ignored when
        computing this range.
    xlim, ylim : tuple of float, optional
        Explicit axis limits. When passed they are applied last, so they take
        precedence over ``trim_limits``.
    kwargs : kwargs
        Typical plotter kwargs.

    Returns
    -------
    pyplot axis
        The axis plotted on.
    """
    # prepare data for plotting
    exp, var, sort_idx, raw_exps, raw_vars = prepare_exp_var_for_plotting(
        exp, var, obs=obs, num_groups=num_groups,
        group_fractional_tolerance=group_fractional_tolerance,
        exclude_offdiagonals=exclude_offdiagonals, log=log, vst=vst)

    # plot
    if obs is None:
        if scatter:
            plot_scatter(exp, var, logx=logx, logy=logy, hexbin=hexbin,
                         xlim=xlim, ylim=ylim)
        else:
            plt.plot(exp[sort_idx], var[sort_idx], color='r')
    else:
        # re-estimated group variances as points, supplied exp/var as a curve
        plt.scatter(raw_exps, raw_vars, color='b')
        plt.plot(exp[sort_idx], var[sort_idx], color='r')
        if trim_limits:
            # nanmin/nanmax: np.nanvar returns NaN for empty/all-NaN groups,
            # and a NaN limit makes matplotlib raise a ValueError
            plt.xlim(np.nanmin(raw_exps), np.nanmax(raw_exps))
            plt.ylim(np.nanmin(raw_vars), np.nanmax(raw_vars))

    if logx:
        plt.gca().set_xscale('log')
    if logy:
        plt.gca().set_yscale('log')

    # honor explicit limits in every branch, not only the scatter branch
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)


plot_mvr_parallel = parallelize_regions(plot_mvr)


def _resolve_region_colors(colors, regions, default_colors):
    """
    Expands a user-supplied color specification into a per-region dict.

    Parameters
    ----------
    colors : None, dict, or a single matplotlib color
        None selects ``default_colors``. A dict maps region names to colors;
        regions missing from it fall back to ``default_colors``. Anything else
        (a color name string, an RGB(A) tuple, ...) is used for every region.
    regions : list
        The region names (keys of the counts dicts) to resolve colors for.
    default_colors : dict
        Fallback color for each region in ``regions``.

    Returns
    -------
    dict
        Mapping from every region in ``regions`` to a color.
    """
    if colors is None:
        return dict(default_colors)
    if isinstance(colors, dict):
        return {region: colors.get(region, default_colors[region])
                for region in regions}
    # any non-dict value is treated as one color shared by all regions
    return {region: colors for region in regions}


@plotter
def plot_overlay_mvr(exp, var, obs=None, num_groups=100,
                     group_fractional_tolerance=0.1, exclude_offdiagonals=5,
                     log=False, logx=False, logy=False, vst=False,
                     scatter=False, scatter_colors=None, line_colors=None,
                     legend='outside', **kwargs):
    """
    Plots a comparison of mean-variance relationships across regions.

    Parameters
    ----------
    exp : dict of np.ndarray
        Counts dict of expected values.
    var : dict of np.ndarray
        Counts dict of variance values.
    obs : dict of np.ndarray, optional
        Counts dict of observed values.
    num_groups : int
        The number of groups to re-estimate group variances for.
    group_fractional_tolerance : float
        The width of each group, specified as a fractional tolerance in the
        expected value.
    exclude_offdiagonals : int
        Exclude this many off-diagonals from the variance re-estimation. Pass 0
        to exclude only the exact diagonal. Pass -1 to exclude nothing.
    log : bool
        Pass True to log both exp and var.
    logx, logy : bool
        Pass True to draw the x- and/or y-axis on a log-scale.
    vst : bool
        Pass True to log only the exp (e.g., when var is already stabilized).
    scatter : bool
        Pass True to force plotting exp vs var as a scatterplot when obs is not
        passed. By default it will be a line plot.
    scatter_colors, line_colors : str, color tuple, or dict, optional
        Mapping from region names to the color to use for that region. Pass None
        to use automatically assigned ('husl' palette) colors. Pass a single
        color (name string or RGB(A) tuple) to use the same color for all
        regions. Regions missing from a dict fall back to the palette color.
    legend : str
        Not used in this function body; presumably consumed by the ``plotter``
        decorator to control legend placement (confirm against ``plotter``).
    kwargs : kwargs
        Typical plotter kwargs.

    Returns
    -------
    pyplot axis
        The axis plotted on.
    """
    # prepare data for plotting (one independent result per region)
    exp, var, sort_idx, raw_exps, raw_vars = \
        prepare_exp_var_for_plotting_parallel(
            exp, var, obs=obs, num_groups=num_groups,
            group_fractional_tolerance=group_fractional_tolerance,
            exclude_offdiagonals=exclude_offdiagonals, log=log, vst=vst)

    # resolve colors
    regions = sorted(exp.keys())
    palette = sns.color_palette('husl', len(regions))
    default_colors = dict(zip(regions, palette))
    scatter_colors = _resolve_region_colors(scatter_colors, regions,
                                            default_colors)
    line_colors = _resolve_region_colors(line_colors, regions, default_colors)

    # plot
    if obs is None:
        if scatter:
            for region in regions:
                plt.scatter(exp[region], var[region],
                            color=scatter_colors[region], label=region)
        else:
            for region in regions:
                plt.plot(exp[region][sort_idx[region]],
                         var[region][sort_idx[region]],
                         color=line_colors[region], label=region)
    else:
        # all regions' group-variance points first, then all mvr curves
        for region in regions:
            plt.scatter(raw_exps[region], raw_vars[region],
                        color=scatter_colors[region], label='%s data' % region)
        for region in regions:
            plt.plot(exp[region][sort_idx[region]],
                     var[region][sort_idx[region]],
                     color=line_colors[region], label='%s mvr' % region)

    if logx:
        plt.gca().set_xscale('log')
    if logy:
        plt.gca().set_yscale('log')


def prepare_exp_var_for_plotting(exp, var, obs=None, num_groups=100,
                                 group_fractional_tolerance=0.1,
                                 exclude_offdiagonals=5, log=False, vst=False):
    """
    Prepares expected value and variance data for plotting.

    Parameters
    ----------
    exp : np.ndarray or dict of np.ndarray
        Regional matrix of expected values. Pass a counts dict to combine all
        regions together.
    var : np.ndarray or dict of np.ndarray
        Regional matrix of variances. Pass a counts dict to combine all regions
        together.
    obs : np.ndarray or dict of np.ndarray, optional
        Regional matrix of observed values. Pass a counts dict to combine all
        regions together.
    num_groups : int
        The number of groups to re-estimate group variances for.
    group_fractional_tolerance : float
        The width of each group, specified as a fractional tolerance in the
        expected value.
    exclude_offdiagonals : int
        Exclude this many off-diagonals from the variance re-estimation. Pass 0
        to exclude only the exact diagonal. Pass -1 to exclude nothing.
    log : bool
        Pass True to log both exp and var.
    vst : bool
        Pass True to log only the exp (e.g., when var is already stabilized).
        Note that combining ``vst=True`` with ``log=True`` applies
        ``log(x + 1)`` to the group expected values twice.

    Returns
    -------
    tuple of np.ndarray
        The first and second elements are parallel arrays of the exp, var pairs.
        The third element is a sort index into the exp, var pairs. The fourth
        and fifth elements are None if obs was not passed. If obs was passed,
        they are parallel arrays of the group expected values and the variance
        of each group returned by ``group_obs_by_exp`` (NaN-aware).
    """
    # prepare raw exp, var pairs if obs was passed
    if obs is not None:
        raw_exps, groups = group_obs_by_exp(
            obs, exp,
            num_groups=num_groups,
            group_fractional_tolerance=group_fractional_tolerance,
            exclude_offdiagonals=exclude_offdiagonals)
        # asarray so the elementwise arithmetic below also works if the
        # helper hands back plain lists
        raw_exps = np.asarray(raw_exps)
        if vst:
            raw_exps = np.log(raw_exps + 1)
            raw_vars = np.array([np.nanvar(np.log(np.asarray(group) + 1))
                                 for group in groups])
        else:
            raw_vars = np.array([np.nanvar(group) for group in groups])
        if log:
            raw_exps = np.log(raw_exps + 1)
            raw_vars = np.log(raw_vars + 1)
    else:
        raw_exps = None
        raw_vars = None

    # prepare exp, var pairs; a counts dict is combined across all regions
    if isinstance(exp, dict):
        exp, var = flatten_obs_and_exp_counts(exp, var, log=log)
    else:
        exp, var = flatten_obs_and_exp(exp, var, log=log)
    if vst:
        exp = np.log(exp + 1)

    # compute sort index for exp/var pairs (used to draw them as a curve)
    sort_idx = np.argsort(exp)
    return exp, var, sort_idx, raw_exps, raw_vars


prepare_exp_var_for_plotting_parallel = parallelize_regions(
    prepare_exp_var_for_plotting)