"""
Module for the ClusterExtendableHeatmap class, which adds cluster outlining
functionality for the extendable heatmap system.
"""
import numpy as np
from matplotlib import cm
from lib5c.plotters.extendable.base_extendable_heatmap import \
BaseExtendableHeatmap
from lib5c.algorithms.clustering.util import reshape_cluster_array_to_dict,\
center_of_mass, belongs_to
[docs]class ClusterExtendableHeatmap(BaseExtendableHeatmap):
"""
ExtendableHeatmap mixin class providing cluster outlining functionality.
"""
[docs] def add_clusters(self, cluster_array, colors=None, weight='100x',
outline_color=None, outline_weight='2x', labels=None,
fontsize=7):
"""
Adds clusters to the heatmap surface.
Parameters
----------
cluster_array: np.ndarray
Array of cluster IDs. Should match size and shape of the underlying
array this ExtendableHeatmap was constructed with.
colors: 'random' or single color or list/dict of colors or None
Pass 'random' for random colors, pass a dict mapping cluster IDs to
matplotlib colors to outline each cluster in the indicated color,
pass None to skip outlining clusters.
weight : numeric or str
Pass a numeric to set the linewidth for the cluster outlines. Pass a
string ending in "x" (such as "100x") to specify the line width as a
multiple of the inverse of the number of pixels in the heatmap.
outline_color : matplotlib color or None
Pass a matplotlib color to outline the outlines (e.g. with neon
green) to make them stand out more. Pass None to skip adding this
extra outline.
outline_weight : numeric or str
ass a numeric to set the linewidth for the outlines of the cluster
outlines. Pass a string ending in "x" (such as "2x") to specify the
line width as a multiple of the outline linewidth.
labels : True, dict of str, or None
Pass True to simply label the clusters by their ID. Pass a mapping
from cluster IDs to labels to label the clusters with a the labels.
Pass None to skip outlining clusters.
fontsize : numeric
The font size to use for cluster labels.
"""
clusters = reshape_cluster_array_to_dict(cluster_array)
if colors:
# resolve line width
base_linewidth = weight
if type(weight) == str:
base_linewidth = float(weight[:-1]) / len(self.array)
# resolve outline line width
outline_linewidth = base_linewidth * 2
if outline_color is not None and outline_weight is not None:
if type(outline_weight) == str:
outline_linewidth = float(
outline_weight[:-1]) * base_linewidth
else:
outline_linewidth = base_linewidth + outline_weight
# resolve colors
if colors is None:
pass
elif colors == 'random':
color_multiplier = 256 / max(len(clusters) - 1, 1)
cluster_ids = list(clusters.keys())
colors = {cluster_ids[i]: cm.gist_ncar(i * color_multiplier)
for i in range(len(cluster_ids))}
elif type(colors) not in [dict, list]:
colors = {cluster_id: colors for cluster_id in clusters}
# resolve labels
if labels is True:
labels = {cluster_id: cluster_id for cluster_id in clusters}
# add outlines of the outlines first
if outline_color is not None:
for cluster_id in clusters:
self.outline_cluster(clusters[cluster_id], outline_color,
linewidth=outline_linewidth)
# outline clusters
if colors is not None:
for cluster_id in clusters:
self.outline_cluster(clusters[cluster_id],
colors[cluster_id],
linewidth=base_linewidth)
# label clusters
if labels is not None:
for cluster_id in clusters:
self.label_cluster(clusters[cluster_id], labels[cluster_id],
fontsize=fontsize)
[docs] def outline_cluster(self, cluster, color, linewidth=2):
"""
Outlines a single cluster in the specified color.
Parameters
----------
cluster : list of {'x': int, 'y': int} dicts
The cluster to outline.
color : matplotlib color
The color to outline in.
linewidth : numeric
The linewidth to use.
"""
# reference to axis to plot to
ax = self['root']
for peak in cluster:
# top
query_peak = {'x': peak['x'], 'y': peak['y'] - 1}
if not belongs_to(query_peak, cluster):
x = peak['x']
y = peak['y']
ax.plot([x, x + 1], [y, y], c=color, lw=linewidth)
# bottom
query_peak = {'x': peak['x'], 'y': peak['y'] + 1}
if not belongs_to(query_peak, cluster):
x = peak['x']
y = peak['y']
ax.plot([x, x + 1], [y + 1, y + 1], c=color, lw=linewidth)
# left
query_peak = {'x': peak['x'] - 1, 'y': peak['y']}
if not belongs_to(query_peak, cluster):
x = peak['x']
y = peak['y']
ax.plot([x, x], [y, y + 1], c=color, lw=linewidth)
# right
query_peak = {'x': peak['x'] + 1, 'y': peak['y']}
if not belongs_to(query_peak, cluster):
x = peak['x']
y = peak['y']
ax.plot([x + 1, x + 1], [y, y + 1], c=color, lw=linewidth)
[docs] def label_cluster(self, cluster, label, fontsize=7):
"""
Labels a cluster.
Parameters
----------
cluster : list of {'x': int, 'y': int} dicts
The cluster to label.
label : str
The string to label the cluster with.
fontsize : numeric
The fontsize to use for the label.
"""
centroid = center_of_mass(cluster) + np.array([0.5, 0.5])
self['root'].text(centroid[0], centroid[1], str(label),
fontsize=fontsize, ha='center', va='center')