Source code for lib5c.plotters.extendable.cluster_extendable_heatmap

"""
Module for the ClusterExtendableHeatmap class, which adds cluster outlining
functionality for the extendable heatmap system.
"""

import numpy as np
from matplotlib import cm

from lib5c.plotters.extendable import BaseExtendableHeatmap
from lib5c.algorithms.clustering.util import reshape_cluster_array_to_dict,\
    center_of_mass, belongs_to


[docs]class ClusterExtendableHeatmap(BaseExtendableHeatmap): """ ExtendableHeatmap mixin class providing cluster outlining functionality. """
[docs] def add_clusters(self, cluster_array, colors=None, weight='100x', outline_color=None, outline_weight='2x', labels=None, fontsize=7): """ Adds clusters to the heatmap surface. Parameters ---------- cluster_array: np.ndarray Array of cluster IDs. Should match size and shape of the underlying array this ExtendableHeatmap was constructed with. colors: 'random' or single color or list/dict of colors or None Pass 'random' for random colors, pass a dict mapping cluster IDs to matplotlib colors to outline each cluster in the indicated color, pass None to skip outlining clusters. weight : numeric or str Pass a numeric to set the linewidth for the cluster outlines. Pass a string ending in "x" (such as "100x") to specify the line width as a multiple of the inverse of the number of pixels in the heatmap. outline_color : matplotlib color or None Pass a matplotlib color to outline the outlines (e.g. with neon green) to make them stand out more. Pass None to skip adding this extra outline. outline_weight : numeric or str ass a numeric to set the linewidth for the outlines of the cluster outlines. Pass a string ending in "x" (such as "2x") to specify the line width as a multiple of the outline linewidth. labels : True, dict of str, or None Pass True to simply label the clusters by their ID. Pass a mapping from cluster IDs to labels to label the clusters with a the labels. Pass None to skip outlining clusters. fontsize : numeric The font size to use for cluster labels. """ clusters = reshape_cluster_array_to_dict(cluster_array) if colors: # resolve line width base_linewidth = weight if type(weight) == str: base_linewidth = float(weight[:-1]) / len(self.array) # resolve outline line width outline_linewidth = base_linewidth * 2 if outline_color is not None and outline_weight is not None: if type(outline_weight) == str: outline_linewidth = float( outline_weight[:-1]) * base_linewidth else: outline_linewidth = base_linewidth + outline_weight # resolve colors if colors is None: pass elif colors == 'random': color_multiplier = 256 / max(len(clusters) - 1, 1) cluster_ids = list(clusters.keys()) colors = {cluster_ids[i]: cm.gist_ncar(i * color_multiplier) for i in range(len(cluster_ids))} elif type(colors) not in [dict, list]: colors = {cluster_id: colors for cluster_id in clusters} # resolve labels if labels is True: labels = {cluster_id: cluster_id for cluster_id in clusters} # add outlines of the outlines first if outline_color is not None: for cluster_id in clusters: self.outline_cluster(clusters[cluster_id], outline_color, linewidth=outline_linewidth) # outline clusters if colors is not None: for cluster_id in clusters: self.outline_cluster(clusters[cluster_id], colors[cluster_id], linewidth=base_linewidth) # label clusters if labels is not None: for cluster_id in clusters: self.label_cluster(clusters[cluster_id], labels[cluster_id], fontsize=fontsize)
[docs] def outline_cluster(self, cluster, color, linewidth=2): """ Outlines a single cluster in the specified color. Parameters ---------- cluster : list of {'x': int, 'y': int} dicts The cluster to outline. color : matplotlib color The color to outline in. linewidth : numeric The linewidth to use. """ # reference to axis to plot to ax = self['root'] for peak in cluster: # top query_peak = {'x': peak['x'], 'y': peak['y'] - 1} if not belongs_to(query_peak, cluster): x = peak['x'] y = peak['y'] ax.plot([x, x + 1], [y, y], c=color, lw=linewidth) # bottom query_peak = {'x': peak['x'], 'y': peak['y'] + 1} if not belongs_to(query_peak, cluster): x = peak['x'] y = peak['y'] ax.plot([x, x + 1], [y + 1, y + 1], c=color, lw=linewidth) # left query_peak = {'x': peak['x'] - 1, 'y': peak['y']} if not belongs_to(query_peak, cluster): x = peak['x'] y = peak['y'] ax.plot([x, x], [y, y + 1], c=color, lw=linewidth) # right query_peak = {'x': peak['x'] + 1, 'y': peak['y']} if not belongs_to(query_peak, cluster): x = peak['x'] y = peak['y'] ax.plot([x + 1, x + 1], [y, y + 1], c=color, lw=linewidth)
[docs] def label_cluster(self, cluster, label, fontsize=7): """ Labels a cluster. Parameters ---------- cluster : list of {'x': int, 'y': int} dicts The cluster to label. label : str The string to label the cluster with. fontsize : numeric The fontsize to use for the label. """ centroid = center_of_mass(cluster) + np.array([0.5, 0.5]) self['root'].text(centroid[0], centroid[1], str(label), fontsize=fontsize, ha='center', va='center')