Source code for lib5c.plotters.correlation

"""
Module for plotting pairwise correlation matrices.
"""

import matplotlib.pyplot as plt
import seaborn as sns

from lib5c.util.plotting import plotter


@plotter
def plot_correlation_matrix(matrix, label_values=None, cluster=False,
                            cbar=False, cmap='rocket_r', colorscale=None,
                            despine=False, style='dark', **kwargs):
    """
    Plots a pairwise correlation matrix as a heatmap.

    Parameters
    ----------
    matrix : np.ndarray
        The pairwise correlation matrix to visualize.
    label_values : Optional[List[str]]
        A list of strings labeling the columns (and rows) of the matrix. If not
        passed, no labels will be included.
    cluster : bool
        Pass True to perform hierarchical clustering on the rows and columns of
        the matrix (via ``sns.clustermap`` with the cosine metric).
    cbar : bool
        Pass True to include a colorbar. Only honored when ``cluster`` is
        False; the clustered branch does not forward this flag.
    cmap : matplotlib colormap
        Choose the colormap to use in the heatmap.
    colorscale : Optional[Tuple[float, float]]
        Pass a ``(vmin, vmax)`` pair to fix the color scale of the plot. If not
        passed, the color limits are left to seaborn.
    despine, style :
        Not read by this function body; presumably consumed by the
        ``@plotter`` decorator (TODO confirm).
    kwargs : kwargs
        Typical plotter kwargs. They are not forwarded to seaborn.

    Returns
    -------
    pyplot axis
        The axis plotted on. Note: this body itself returns None; the axis is
        presumably supplied by the ``@plotter`` decorator (TODO confirm).
    """
    # resolve vmin and vmax from colorscale
    if colorscale is not None:
        vmin, vmax = colorscale
    else:
        vmin, vmax = (None, None)

    if cluster:
        # sns.clustermap always creates its own figure, so no figure is
        # pre-created here -- doing so would only leave an empty, never-closed
        # figure behind.
        grid = sns.clustermap(data=matrix, metric='cosine', square=True,
                              annot=True, xticklabels=label_values,
                              cmap=cmap, yticklabels=label_values,
                              vmin=vmin, vmax=vmax)
        if label_values is not None:
            plt.setp(grid.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
            plt.setp(grid.ax_heatmap.xaxis.get_majorticklabels(),
                     rotation=45, ha='right')
    else:
        # for larger matrices, give each row/column one inch of figure size
        if len(matrix) > 6:
            plt.figure(figsize=(len(matrix), len(matrix)))
        sns.heatmap(data=matrix, square=True, annot=True,
                    xticklabels=label_values, yticklabels=label_values,
                    cbar=cbar, cmap=cmap, vmin=vmin, vmax=vmax)
        if label_values is not None:
            plt.yticks(rotation=0)
            plt.xticks(rotation=45, ha='right')