# !/usr/bin/env python3
import os, sys
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pylab as plt
from scipy.cluster import hierarchy
import collections
import warnings
import copy
from .utils import mm2inch

from .utils import (

class heatmapPlotter:
    """Prepare and draw a single heatmap (with optional annotations) on one Axes.

    ``__init__`` resolves data, mask, tick labels, axis labels, colormap and
    annotation settings; :meth:`plot` renders them with ``ax.pcolormesh`` and
    optionally adds a colorbar.
    """

    def __init__(
        self,
        data=None,
        vmin=None,
        vmax=None,
        cmap="bwr",
        center=None,
        robust=True,
        annot=None,
        fmt=".2g",
        annot_kws=None,
        cbar=True,
        cbar_kws=None,
        xlabel=None,
        ylabel=None,
        xticklabels=True,
        yticklabels=True,
        mask=None,
        na_col="white",
    ):
        """Initialize the plotting object.

        Parameters
        ----------
        data : pd.DataFrame or array-like
            Values to plot; non-DataFrame input is wrapped in a DataFrame.
        vmin, vmax : float, optional
            Colormap limits; inferred from the data when None.
        cmap : str, list, Colormap or None
            Colormap name, list of colors, or Colormap instance. None means
            "turbo" (or "exp1" when ``center`` is given).
        center : float, optional
            Value at which to re-center a divergent colormap.
        robust : bool
            Use the 2nd/98th percentiles instead of min/max for missing
            ``vmin``/``vmax``.
        annot : None, bool or array-like
            True annotates cells with their data value; an array with the same
            shape as ``data`` annotates with its own values.
        fmt : str or None
            Format spec for annotations; None writes the raw value.
        annot_kws, cbar_kws : dict, optional
            Extra kwargs for ``ax.text`` / ``Figure.colorbar`` (copied).
        cbar : bool
            Whether to draw a colorbar.
        xlabel, ylabel : str, optional
            Axis labels; default to the column / index names of ``data``.
        xticklabels, yticklabels : bool, int, "auto" or list
            True/False show/hide the labels, an int shows every n-th label,
            "auto" defers the choice to ``_auto_ticks`` at draw time, and a
            list supplies custom labels.
        mask : array-like or DataFrame, optional
            Validated by ``_check_mask``; True cells are masked and drawn with
            ``na_col``.
        na_col : str
            Color used for masked / NaN cells.
        """
        if isinstance(data, pd.DataFrame):
            plot_data = data.values
        else:
            plot_data = np.asarray(data)
            data = pd.DataFrame(plot_data)

        # Validate the mask and hide the masked cells: pcolormesh renders them
        # with the colormap's "bad" color and the vmin/vmax stats ignore them.
        mask = _check_mask(data, mask)
        plot_data = np.ma.masked_where(np.asarray(mask), plot_data)

        # Tick positions and labels for both axes.
        self.xticks, self.xticklabels = self._resolve_ticks(xticklabels, data.columns)
        self.yticks, self.yticklabels = self._resolve_ticks(yticklabels, data.index)

        # Axis labels default to the column / index names.
        if xlabel is None:
            xlabel = _index_to_label(data.columns)
        if ylabel is None:
            ylabel = _index_to_label(data.index)
        self.xlabel = xlabel if xlabel is not None else ""
        self.ylabel = ylabel if ylabel is not None else ""
        # Must be set before _determine_cmap_params, which reads it.
        self.na_col = na_col

        # Determine good default values for the colormapping
        self._determine_cmap_params(plot_data, vmin, vmax, cmap, center, robust)

        # Sort out the annotations
        if annot is None or annot is False:
            annot = False
            annot_data = None
        else:
            if isinstance(annot, bool):
                annot_data = plot_data
            else:
                annot_data = np.asarray(annot)
                if annot_data.shape != plot_data.shape:
                    err = "`data` and `annot` must have same shape."
                    raise ValueError(err)
            annot = True

        # Save other attributes to the object
        self.data = data
        self.plot_data = plot_data
        self.annot = annot
        self.annot_data = annot_data
        self.fmt = fmt
        self.annot_kws = {} if annot_kws is None else annot_kws.copy()
        self.cbar = cbar
        self.cbar_kws = {} if cbar_kws is None else cbar_kws.copy()
        self.cbar_kws.setdefault("aspect", 5)
        self.cbar_kws.setdefault("fraction", 0.08)
        self.cbar_kws.setdefault("shrink", 0.5)

    @staticmethod
    def _resolve_ticks(ticklabels, index):
        """Turn a user ``*ticklabels`` argument into ``(ticks, ticklabels)``.

        ``ticks`` is either the string "auto" (resolved later by
        ``_auto_ticks`` in :meth:`plot`) or the positions returned by
        ``_skip_ticks``.
        """
        if ticklabels is False:
            return [], []
        tickevery = 1
        # bool is a subclass of int, so True lands here as "every label".
        if isinstance(ticklabels, int):
            tickevery = ticklabels
            ticklabels = _index_to_ticklabels(index)
        if not len(ticklabels):
            return [], []
        if isinstance(ticklabels, str) and ticklabels == "auto":
            return "auto", _index_to_ticklabels(index)
        return _skip_ticks(ticklabels, tickevery)

    @staticmethod
    def _copy_colormap(name):
        """Return a private copy of the named colormap so ``set_bad`` etc. do
        not modify a shared instance."""
        cmap = get_colormap(name)
        try:
            return cmap.copy()
        except AttributeError:
            # Colormaps from older matplotlib versions may lack .copy().
            return cmap

    def _determine_cmap_params(self, plot_data, vmin, vmax, cmap, center, robust):
        """Use some heuristics to set good defaults for colorbar and range.

        Sets ``self.vmin``, ``self.vmax`` and ``self.cmap``. ``plot_data`` is
        expected to be a ``np.ma`` masked array (masked cells become NaN and
        are ignored by the nan-aware statistics).
        """
        calc_data = plot_data.astype(float).filled(np.nan)
        if vmin is None:
            vmin = np.nanpercentile(calc_data, 2) if robust else np.nanmin(calc_data)
        if vmax is None:
            vmax = np.nanpercentile(calc_data, 98) if robust else np.nanmax(calc_data)
        self.vmin, self.vmax = vmin, vmax

        # Choose default colormaps if not provided
        if cmap is None:
            self.cmap = self._copy_colormap("turbo" if center is None else "exp1")
        elif isinstance(cmap, str):
            self.cmap = self._copy_colormap(cmap)
        elif isinstance(cmap, list):
            self.cmap = matplotlib.colors.ListedColormap(cmap)
        else:
            self.cmap = cmap
        self.cmap.set_bad(color=self.na_col)  # color for NaN / masked values

        # Recenter a divergent colormap
        if center is not None:
            # Only keep explicit under/over colors when they differ from the
            # colormap's end colors.
            under = self.cmap(-np.inf)
            over = self.cmap(np.inf)
            under_set = under != self.cmap(0)
            over_set = over != self.cmap(self.cmap.N - 1)

            vrange = max(vmax - center, center - vmin)
            normlize = matplotlib.colors.Normalize(center - vrange, center + vrange)
            cmin, cmax = normlize([vmin, vmax])
            cc = np.linspace(cmin, cmax, 256)
            self.cmap = matplotlib.colors.ListedColormap(self.cmap(cc))
            # The freshly built ListedColormap has its own default "bad"
            # color; re-apply na_col so masked cells keep the requested color.
            self.cmap.set_bad(color=self.na_col)
            if under_set:
                self.cmap.set_under(under)
            if over_set:
                self.cmap.set_over(over)

    def _annotate_heatmap(self, ax, mesh):
        """Add textual labels with the value in each (unmasked) cell."""
        mesh.update_scalarmappable()
        height, width = self.annot_data.shape
        xpos, ypos = np.meshgrid(np.arange(width) + 0.5, np.arange(height) + 0.5)
        # .flat: newer matplotlib returns a 2-D array from QuadMesh.get_array();
        # iterating it directly would yield rows and truncate the zip.
        for x, y, m, color, val in zip(
            xpos.flat,
            ypos.flat,
            mesh.get_array().flat,
            mesh.get_facecolors(),
            self.annot_data.flat,
        ):
            if m is np.ma.masked:
                continue
            lum = _calculate_luminance(color)
            # Dark text on light cells, white text on dark cells.
            text_color = ".15" if lum > 0.408 else "w"
            if self.fmt is not None:
                annotation = ("{:" + self.fmt + "}").format(val)
            else:
                annotation = val
            text_kwargs = dict(color=text_color, ha="center", va="center")
            text_kwargs.update(self.annot_kws)
            ax.text(x, y, annotation, **text_kwargs)

    def _set_axes_label(
        self,
        ax,
        xlabel_kws,
        xlabel_bbox_kws,
        ylabel_kws,
        ylabel_bbox_kws,
        xlabel_side,
        ylabel_side,
        xlabel_pad,
        ylabel_pad,
    ):
        """Style and position the x/y axis labels.

        ``xlabel_kws`` / ``ylabel_kws`` are Text properties (alpha, color,
        fontsize, rotation, alignment, ...) applied with ``label.update``;
        they are copied, so the caller's dicts are not modified.
        ``*_bbox_kws`` are passed to ``label.set_bbox``. ``*_side`` is passed
        to ``set_label_position``. ``xlabel_pad`` / ``ylabel_pad`` are
        accepted for interface compatibility but are not used here.
        """
        xlabel_kws = {} if xlabel_kws is None else dict(xlabel_kws)
        xlabel_kws.setdefault("verticalalignment", "center")
        xlabel_kws.setdefault("horizontalalignment", "center")
        ax.xaxis.label.update(xlabel_kws)
        if xlabel_bbox_kws is not None:
            ax.xaxis.label.set_bbox(xlabel_bbox_kws)

        ylabel_kws = {} if ylabel_kws is None else dict(ylabel_kws)
        ylabel_kws.setdefault("horizontalalignment", "center")
        ylabel_kws.setdefault("verticalalignment", "center")
        ax.yaxis.label.update(ylabel_kws)
        if ylabel_bbox_kws is not None:
            ax.yaxis.label.set_bbox(ylabel_bbox_kws)

        if xlabel_side:
            ax.xaxis.set_label_position(xlabel_side)
        if ylabel_side:
            ax.yaxis.set_label_position(ylabel_side)
        _draw_figure(ax.figure)

    def _set_tick_label(
        self, ax, xticklabels_side, yticklabels_side, xticklabels_kws, yticklabels_kws
    ):
        """Place tick labels on the requested sides and apply tick params.

        ``*ticklabels_kws`` go to ``Axis.set_tick_params``. When the x labels
        are rotated, their horizontal alignment is anchored so that rotated
        text lines up with its tick.
        """
        if xticklabels_side == "top":
            ax.xaxis.tick_top()
        elif xticklabels_side == "bottom":
            ax.xaxis.tick_bottom()
        if yticklabels_side == "left":
            ax.yaxis.tick_left()
        elif yticklabels_side == "right":
            ax.yaxis.tick_right()

        if xticklabels_kws is not None:
            ax.xaxis.set_tick_params(**xticklabels_kws)
        else:
            xticklabels_kws = {}
        if yticklabels_kws is not None:
            ax.yaxis.set_tick_params(**yticklabels_kws)

        positive = (
            xticklabels_kws.get("rotation", 0) > 0
            or xticklabels_kws.get("labelrotation", 0) > 0
        )
        negative = (
            xticklabels_kws.get("rotation", 0) < 0
            or xticklabels_kws.get("labelrotation", 0) < 0
        )
        ha = None
        if positive:
            ha = "left" if xticklabels_side == "top" else "right"
        elif negative:
            ha = "right" if xticklabels_side == "top" else "left"
        if ha is not None:
            plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha=ha)

    def plot(
        self,
        ax,
        cax,
        xlabel_kws,
        xlabel_bbox_kws,
        ylabel_kws,
        ylabel_bbox_kws,
        xlabel_side,
        ylabel_side,
        xlabel_pad,
        ylabel_pad,
        xticklabels_side,
        yticklabels_side,
        xticklabels_kws,
        yticklabels_kws,
        kws,
    ):
        """Draw the heatmap on the provided Axes.

        ``cax`` is forwarded to ``Figure.colorbar``. ``kws`` is forwarded to
        ``ax.pcolormesh``; vmin/vmax defaults are inserted into it in place
        unless it contains ``norm``.
        """
        # Remove all the Axes spines
        despine(ax=ax, left=True, bottom=True)

        # setting vmin/vmax in addition to norm is deprecated
        # so avoid setting if norm is set
        if "norm" not in kws:
            kws.setdefault("vmin", self.vmin)
            kws.setdefault("vmax", self.vmax)

        # Draw the heatmap
        mesh = ax.pcolormesh(self.plot_data, cmap=self.cmap, **kws)
        ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))
        # Invert the y axis to show the plot in matrix form (first row on top)
        ax.invert_yaxis()

        # Possibly add a colorbar
        if self.cbar:
            cb = ax.figure.colorbar(mesh, cax=cax, ax=ax, **self.cbar_kws)
            cb.outline.set_linewidth(0)
            # If rasterized is passed to pcolormesh, also rasterize the
            # colorbar to avoid white lines on the PDF rendering
            if kws.get("rasterized", False):
                cb.solids.set_rasterized(True)

        # Add row and column labels
        if isinstance(self.xticks, str) and self.xticks == "auto":
            xticks, xticklabels = _auto_ticks(ax, self.xticklabels, 0)
        else:
            xticks, xticklabels = self.xticks, self.xticklabels
        if isinstance(self.yticks, str) and self.yticks == "auto":
            yticks, yticklabels = _auto_ticks(ax, self.yticklabels, 1)
        else:
            yticks, yticklabels = self.yticks, self.yticklabels

        ax.set(xticks=xticks, yticks=yticks)
        xtl = ax.set_xticklabels(xticklabels)
        ytl = ax.set_yticklabels(yticklabels, rotation="vertical")
        plt.setp(ytl, va="center")
        plt.setp(xtl, ha="center")

        # Possibly rotate them if they overlap
        _draw_figure(ax.figure)
        if axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        if axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")

        # Annotate the cells with the formatted values
        if self.annot:
            self._annotate_heatmap(ax, mesh)

        # Add the axis labels; tick labels are placed before styling the axis
        # labels.
        ax.set(xlabel=self.xlabel, ylabel=self.ylabel)
        self._set_tick_label(
            ax, xticklabels_side, yticklabels_side, xticklabels_kws, yticklabels_kws
        )
        self._set_axes_label(
            ax,
            xlabel_kws,
            xlabel_bbox_kws,
            ylabel_kws,
            ylabel_bbox_kws,
            xlabel_side,
            ylabel_side,
            xlabel_pad,
            ylabel_pad,
        )
def heatmap(
    data,
    xlabel=None,
    ylabel=None,
    xlabel_side="bottom",
    ylabel_side="left",
    vmin=None,
    vmax=None,
    cmap=None,
    center=None,
    robust=False,
    cbar=True,
    cbar_kws=None,
    cbar_ax=None,
    square=False,
    xlabel_kws=None,
    ylabel_kws=None,
    xlabel_bbox_kws=None,
    ylabel_bbox_kws=None,
    xlabel_pad=None,
    ylabel_pad=None,
    xticklabels="auto",
    yticklabels="auto",
    xticklabels_side="bottom",
    yticklabels_side="left",
    xticklabels_kws=None,
    yticklabels_kws=None,
    mask=None,
    na_col="white",
    ax=None,
    annot=None,
    fmt=".2g",
    annot_kws=None,
    linewidths=0,
    linecolor="white",
    **kwargs
):
    """
    Plot heatmap.

    Parameters
    ----------
    data : dataframe
        pandas dataframe (or array-like, wrapped into a DataFrame).
    xlabel / ylabel : str
        Axis labels; default to the column / index names of ``data``.
    xlabel_side : str
        "bottom" (default) or "top".
    ylabel_side : str
        "left" (default) or "right".
    vmax, vmin : float
        the maximal and minimal values for cmap colorbar.
    cmap : str, list or Colormap
        Colormap; None uses the heatmapPlotter default.
    center, robust :
        the same as seaborn.heatmap
    cbar : bool
        Whether to draw a colorbar.
    cbar_kws : dict
        Passed to ``Figure.colorbar``.
    cbar_ax : matplotlib.axes.Axes
        Axes in which to draw the colorbar.
    square : bool
        If True, set an equal aspect ratio on ``ax``.
    xlabel_kws / ylabel_kws : dict
        Text properties (color, fontsize, rotation, alignment, ...) applied to
        the axis labels.
    xlabel_bbox_kws / ylabel_bbox_kws : dict
        Passed to the axis label's ``set_bbox``.
    xlabel_pad / ylabel_pad : float
        Default 0.3; forwarded to ``heatmapPlotter.plot``.
    xticklabels / yticklabels : "auto", bool, int or list
        Which tick labels to show (see heatmapPlotter).
    xticklabels_side / yticklabels_side : str
        "bottom"/"top" and "left"/"right".
    xticklabels_kws / yticklabels_kws : dict
        Passed to ``Axis.set_tick_params``.
    mask : array-like
        Cells to hide; drawn with ``na_col``.
    na_col : str
        Color for masked / NaN cells.
    ax : matplotlib.axes.Axes
        Target Axes; defaults to the current Axes.
    annot, fmt, annot_kws :
        Cell annotation settings (see heatmapPlotter).
    linewidths, linecolor :
        Forwarded to ``pcolormesh`` as ``linewidths`` / ``edgecolor``.
    **kwargs :
        Other keyword arguments forwarded to ``ax.pcolormesh``.

    Returns
    -------
    matplotlib.axes.Axes
        The Axes the heatmap was drawn on.
    """
    plotter = heatmapPlotter(
        data=data,
        vmin=vmin,
        vmax=vmax,
        cmap=cmap,
        center=center,
        robust=robust,
        annot=annot,
        fmt=fmt,
        annot_kws=annot_kws,
        cbar=cbar,
        cbar_kws=cbar_kws,
        xlabel=xlabel,
        ylabel=ylabel,
        xticklabels=xticklabels,
        yticklabels=yticklabels,
        mask=mask,
        na_col=na_col,
    )
    # Add the pcolormesh kwargs here
    kwargs["linewidths"] = linewidths
    kwargs["edgecolor"] = linecolor
    # Draw the plot and return the Axes
    if ax is None:
        ax = plt.gca()
    if square:
        ax.set_aspect("equal")
    if xlabel_pad is None:
        xlabel_pad = 0.3
    if ylabel_pad is None:
        ylabel_pad = 0.3
    plotter.plot(
        ax,
        cbar_ax,
        xlabel_kws,
        xlabel_bbox_kws,
        ylabel_kws,
        ylabel_bbox_kws,
        xlabel_side,
        ylabel_side,
        xlabel_pad,
        ylabel_pad,
        xticklabels_side,
        yticklabels_side,
        xticklabels_kws,
        yticklabels_kws,
        kwargs,
    )
    return ax
def plot_heatmap(
    data,
    vmin=None,
    vmax=None,
    cmap=None,
    center=None,
    robust=False,
    annot=None,
    fmt=".2g",
    annot_kws=None,
    xticklabels=True,
    yticklabels=True,
    mask=None,
    na_col="white",
    ax=None,
    linewidths=0,
    linecolor="white",
    **kwargs
):
    """
    Plot heatmap.

    Parameters
    ----------
    data : dataframe
        pandas dataframe (or array-like, wrapped into a DataFrame).
    vmax, vmin : float
        the maximal and minimal values for cmap colorbar; inferred from the
        data (2nd/98th percentile when ``robust``) when None.
    cmap : str, list or Colormap
        Colormap name, list of colors or Colormap. None falls back to
        "turbo" (or "exp1" when ``center`` is given).
    center, robust :
        the same as seaborn.heatmap
    annot : bool or array-like
        whether to add annotation for values; an array with the same shape
        as ``data`` provides custom annotations.
    fmt : str
        annotation format; None writes the raw value.
    annot_kws : dict
        passed to ax.text
    xticklabels, yticklabels : bool, int or "auto"
        whether to show ticklabels (taken from the columns / index); an int
        shows every n-th label, "auto" lets ``_auto_ticks`` decide.
    mask : array-like
        True cells are masked and drawn with ``na_col``.
    na_col : str
        Color for masked / NaN cells.
    ax : matplotlib.axes.Axes
        Target Axes; defaults to the current Axes.
    linewidths, linecolor :
        Forwarded to ``pcolormesh`` as ``linewidths`` / ``edgecolor``.
    **kwargs :
        Other keyword arguments forwarded to ``ax.pcolormesh``.

    Returns
    -------
    matplotlib.axes.Axes
    """
    if ax is None:
        ax = plt.gca()

    if isinstance(data, pd.DataFrame):
        plot_data = data.values
    else:
        plot_data = np.asarray(data)
        data = pd.DataFrame(plot_data)

    # Validate the mask and hide masked cells (drawn with the "bad" color).
    mask = _check_mask(data, mask)
    plot_data = np.ma.masked_where(np.asarray(mask), plot_data)

    def _resolve_ticks(ticklabels, index):
        # Labels always come from the index; ticklabels only selects which
        # ones to show. bool is an int subclass, so True means "every label".
        if ticklabels is False:
            return [], []
        tickevery = ticklabels if isinstance(ticklabels, int) else 1
        if isinstance(ticklabels, str) and ticklabels == "auto":
            return "auto", _index_to_ticklabels(index)
        return _skip_ticks(_index_to_ticklabels(index), tickevery)

    xticks, xticklabels = _resolve_ticks(xticklabels, data.columns)
    yticks, yticklabels = _resolve_ticks(yticklabels, data.index)

    # Determine good default values for the colormapping
    calc_data = plot_data.astype(float).filled(np.nan)
    if vmin is None:
        vmin = np.nanpercentile(calc_data, 2) if robust else np.nanmin(calc_data)
    if vmax is None:
        vmax = np.nanpercentile(calc_data, 98) if robust else np.nanmax(calc_data)

    # Choose default colormaps if not provided (None used to crash below).
    if cmap is None:
        cmap = "turbo" if center is None else "exp1"
    if isinstance(cmap, str):
        cmap = get_colormap(cmap)
        try:
            cmap = cmap.copy()  # avoid mutating a shared colormap
        except AttributeError:
            pass  # older matplotlib colormaps may lack .copy()
    elif isinstance(cmap, list):
        cmap = matplotlib.colors.ListedColormap(cmap)
    cmap.set_bad(color=na_col)  # set the color for NaN values

    # Recenter a divergent colormap
    if center is not None:
        under = cmap(-np.inf)
        over = cmap(np.inf)
        under_set = under != cmap(0)
        over_set = over != cmap(cmap.N - 1)
        vrange = max(vmax - center, center - vmin)
        normlize = matplotlib.colors.Normalize(center - vrange, center + vrange)
        cmin, cmax = normlize([vmin, vmax])
        cc = np.linspace(cmin, cmax, 256)
        cmap = matplotlib.colors.ListedColormap(cmap(cc))
        # The new ListedColormap has its own default "bad" color.
        cmap.set_bad(color=na_col)
        if under_set:
            cmap.set_under(under)
        if over_set:
            cmap.set_over(over)

    # Sort out the annotations
    if annot is None or annot is False:
        annot = False
        annot_data = None
    else:
        if isinstance(annot, bool):
            annot_data = plot_data
        else:
            annot_data = np.asarray(annot)
            if annot_data.shape != plot_data.shape:
                err = "`data` and `annot` must have same shape."
                raise ValueError(err)
        annot = True
    if annot_kws is None:
        annot_kws = {}

    # Add the pcolormesh kwargs here
    kwargs["linewidths"] = linewidths
    kwargs["edgecolor"] = linecolor

    despine(ax=ax, left=True, bottom=True)
    if "norm" not in kwargs:
        kwargs.setdefault("vmin", vmin)
        kwargs.setdefault("vmax", vmax)

    # Draw the heatmap
    mesh = ax.pcolormesh(plot_data, cmap=cmap, **kwargs)
    ax.set(xlim=(0, data.shape[1]), ylim=(0, data.shape[0]))
    # Invert the y axis to show the plot in matrix form (first row on top)
    ax.invert_yaxis()

    # Add row and column labels
    if isinstance(xticks, str) and xticks == "auto":
        xticks, xticklabels = _auto_ticks(ax, xticklabels, 0)
    if isinstance(yticks, str) and yticks == "auto":
        yticks, yticklabels = _auto_ticks(ax, yticklabels, 1)
    ax.set(xticks=xticks, yticks=yticks)
    xtl = ax.set_xticklabels(xticklabels)
    ytl = ax.set_yticklabels(yticklabels, rotation="vertical")

    _draw_figure(ax.figure)
    if axis_ticklabels_overlap(xtl):
        plt.setp(xtl, rotation="vertical")
    if axis_ticklabels_overlap(ytl):
        plt.setp(ytl, rotation="horizontal")

    # Annotate the cells with the formatted values
    if annot:
        mesh.update_scalarmappable()
        height, width = annot_data.shape
        xpos, ypos = np.meshgrid(np.arange(width) + 0.5, np.arange(height) + 0.5)
        # .flat: newer matplotlib returns a 2-D array from get_array().
        for x, y, m, color, val in zip(
            xpos.flat,
            ypos.flat,
            mesh.get_array().flat,
            mesh.get_facecolors(),
            annot_data.flat,
        ):
            if m is np.ma.masked:
                continue
            lum = _calculate_luminance(color)
            text_color = ".15" if lum > 0.408 else "w"
            if fmt is not None:
                annotation = ("{:" + fmt + "}").format(val)
            else:
                annotation = val
            text_kwargs = dict(color=text_color, ha="center", va="center")
            text_kwargs.update(annot_kws)
            ax.text(x, y, annotation, **text_kwargs)

    ax.xaxis.tick_bottom()
    ax.yaxis.tick_left()
    return ax
class DendrogramPlotter(object):
    """Compute a hierarchical-clustering dendrogram and draw it on an Axes."""

    def __init__(self, data=None, linkage=None, metric='correlation', method='average',
                 axis=0, label=True, rotate=False, sizes=None, dendrogram_kws=None):
        """
        Dendrogram plotter.

        Parameters
        ----------
        data : pd.DataFrame
            Values to cluster (array-like is wrapped into a DataFrame).
        linkage : custom linkage matrix; computed from ``data`` when None.
        metric : str
            Distance metric passed to the linkage function.
        method : str
            Linkage method passed to the linkage function.
        axis : int
            0 for rows (default) and 1 for columns.
        label : bool
            Whether to produce tick labels from the clustered index.
        rotate : bool
            If axis==0 and one would like to plot row dendrogram, rotate should be True
        sizes : list
            Sizes of rows or columns to scale x or y coordinate based on given sizes.
        dendrogram_kws : dict
            Extra kwargs for ``scipy.cluster.hierarchy.dendrogram``.
        """
        self.axis = axis
        if self.axis == 1:
            # Column dendrogram: transpose so that each column is one observation.
            data = data.T
        # Sets self.data (DataFrame, NaNs filled) and self.array.
        self.check_array(data)
        self.shape = self.data.shape
        self.metric = metric
        self.method = method
        self.label = label
        self.rotate = rotate
        self.dendrogram_kws = dendrogram_kws if dendrogram_kws is not None else {}

        n_obs = self.data.shape[0]
        if linkage is not None:
            self.linkage = linkage
        elif n_obs > 1:
            self.linkage = self.calculated_linkage
        else:
            # A single observation cannot be clustered (the linkage functions
            # reject it), so there is no linkage to compute.
            self.linkage = None

        if sizes is not None:
            # used to create dendrogram tree for cluster_between_groups;
            # keyed by the same labels the dendrogram reports in "ivl".
            self.sizes = {idx: s for idx, s in zip(self.data.index, sizes)}
        else:
            self.sizes = None

        if n_obs > 1:
            self.dendrogram = self.calculate_dendrogram()
        else:  # only 1 row or col
            self.dendrogram = dict(icoord=[[5, 5, 5, 5]], dcoord=[[0, 1, 1, 0]],
                                   ivl=self.data.index.tolist(), leaves=[0])

        # scipy places leaves at 5, 15, 25, ...; get_coords divides by 10,
        # so leaf i sits at i + 0.5.
        ticks = np.arange(self.data.shape[0]) + 0.5

        if self.label:
            ticklabels = _index_to_ticklabels(self.data.index)
            ticklabels = [ticklabels[i] for i in self.reordered_ind]
            if self.rotate:  # horizontal
                self.xticks = []
                self.yticks = ticks
                self.xticklabels = []
                self.yticklabels = ticklabels
                self.ylabel = _index_to_label(self.data.index)
                self.xlabel = ""
            else:  # vertical
                self.xticks = ticks
                self.yticks = []
                self.xticklabels = ticklabels
                self.yticklabels = []
                self.ylabel = ""
                self.xlabel = _index_to_label(self.data.index)
        else:
            self.xticks, self.yticks = [], []
            self.yticklabels, self.xticklabels = [], []
            self.xlabel, self.ylabel = "", ""

    def check_array(self, data):
        """Store ``data`` as a DataFrame (``self.data``) and ndarray (``self.array``).

        Missing values are replaced by the median of their row so that the
        distance computation does not fail on NaNs.
        """
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data)
        if data.isna().sum().sum() > 0:
            data = data.apply(lambda x: x.fillna(x.median()), axis=1)
        self.data = data
        self.array = data.values

    def _calculate_linkage_scipy(self):
        """Linkage of the rows of ``self.array`` via scipy."""
        linkage = hierarchy.linkage(self.array, method=self.method, metric=self.metric)
        return linkage

    def _calculate_linkage_fastcluster(self):
        """Linkage via fastcluster; raises ImportError if it is not installed."""
        import fastcluster

        # Fastcluster has a memory-saving vectorized version, but only
        # with certain linkage methods, and mostly with euclidean metric
        euclidean_methods = ("centroid", "median", "ward")
        euclidean = self.metric == "euclidean" and self.method in euclidean_methods
        if euclidean or self.method == "single":
            return fastcluster.linkage_vector(
                self.array, method=self.method, metric=self.metric
            )
        return fastcluster.linkage(self.array, method=self.method, metric=self.metric)

    @property
    def calculated_linkage(self):
        """Linkage from fastcluster if available, otherwise from scipy."""
        try:
            return self._calculate_linkage_fastcluster()
        except ImportError:
            # np.prod: np.product was removed in NumPy 2.0.
            if np.prod(self.shape) >= 1000:
                msg = (
                    "Clustering large matrix with scipy. Installing "
                    "`fastcluster` may give better performance."
                )
                warnings.warn(msg)
            return self._calculate_linkage_scipy()

    def calculate_dendrogram(self):
        """Return scipy's dendrogram dict (no plotting), labelled by ``self.data.index``."""
        # Z (linkage) shape = (n-1, 4); icoord/dcoord have one entry per link.
        return hierarchy.dendrogram(
            self.linkage,
            no_plot=True,
            labels=self.data.index.tolist(),
            get_leaves=True,
            **self.dendrogram_kws
        )

    @property
    def reordered_ind(self):
        """Indices of the matrix, reordered by the dendrogram"""
        return self.dendrogram["leaves"]

    def get_coords(self, ax, gap_pixel=None, root_x=None):
        """Compute plotting coordinates for the dendrogram links.

        Sets ``self.dependent_coord`` (heights) and ``self.independent_coord``
        (leaf-axis positions, scaled so leaf i sits at i + 0.5). When
        ``self.sizes`` is given, leaf-axis positions are stretched so each
        leaf occupies its size, with ``gap_pixel`` pixels between leaves.
        Requires ``self.icoord_max`` (set by :meth:`plot`) when sizes are used.
        """
        self.dependent_coord = np.array(self.dendrogram["dcoord"])
        self.independent_coord = np.array(self.dendrogram["icoord"]) / 10
        if self.sizes is None:
            return None

        icoord_max = self.icoord_max
        ratio, x_gap = 1, 0
        # sizes is the number of samples in each group, in leaf order
        sizes = [self.sizes[idx] for idx in self.dendrogram["ivl"]]
        cum_sizes = np.cumsum(sizes)

        if gap_pixel is not None:
            n_gaps = len(self.sizes) - 1
            if self.axis == 1:  # vertical
                ax.set_xlim(0, icoord_max)
                extent = ax.get_window_extent().width
            else:  # horizontal
                ax.set_ylim(0, icoord_max)
                extent = ax.get_window_extent().height
            # Shrink the leaf axis to leave room for the gaps, then convert
            # one gap from pixels to data units.
            ratio = (extent - n_gaps * gap_pixel) / extent
            x_gap = (gap_pixel / extent) * icoord_max

        icoord = np.unique(self.independent_coord.flatten())
        if root_x is None:
            fracs = [x % 1 for x in icoord]
        else:
            # use the root_x from dendrogram for each group
            root_x = [x / size for x, size in zip(root_x, sizes)]
            fracs = [root_x[int(x)] if (x - 0.5) % 1 == 0 else x % 1 for x in icoord]

        # map the old independent_coord to new coord
        xcoord_mapping = {}
        for x, frac in zip(icoord, fracs):
            idx = int(x)
            new_x = frac * sizes[idx] * ratio
            if idx > 0:
                new_x += cum_sizes[idx - 1] * ratio + x_gap * idx
            xcoord_mapping[x] = new_x

        self.independent_coord = np.array(
            [[xcoord_mapping[i] for i in a] for a in self.independent_coord]
        )

    def plot(self, ax, gap_pixel=None, root_x=None, tree_kws=None):
        """Plots a dendrogram of the similarities between data on the axes

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes object upon which the dendrogram is plotted
        gap_pixel, root_x :
            Forwarded to :meth:`get_coords`.
        tree_kws : dict
            Passed to ``ax.plot`` for each link (copied, not modified);
            ``colors`` may be one color or a per-link list (default black).

        Returns
        -------
        self
        """
        # Copy: "colors" is popped below and must not vanish from the
        # caller's dict if it is reused for another dendrogram.
        tree_kws = {} if tree_kws is None else dict(tree_kws)
        tree_kws.setdefault("linewidth", 0.5)
        tree_kws.setdefault("colors", None)
        if self.sizes is None:
            self.icoord_max = len(self.reordered_ind)
        else:
            self.icoord_max = sum(self.sizes.values())
        self.get_coords(ax, gap_pixel, root_x)

        # The last link is the root; its horizontal segment spans columns 1:3.
        self.root_x = np.mean(self.independent_coord[-1][1:3])
        root_y = np.mean(self.dependent_coord[-1][1:3])

        if self.rotate and self.axis == 0:
            # Horizontal tree: heights on x, leaves on y.
            coords = zip(self.dependent_coord, self.independent_coord)
        else:
            # Vertical tree: leaves on x, heights on y.
            coords = zip(self.independent_coord, self.dependent_coord)

        colors = tree_kws.pop("colors")
        n_leaves = len(self.dendrogram["ivl"])
        if colors is None:
            colors = ["black"] * n_leaves
        elif isinstance(colors, str):
            colors = [colors] * n_leaves
        for (x, y), color in zip(coords, colors):
            ax.plot(x, y, color=color, **tree_kws)

        if self.rotate:  # if axis==0, rotate should be set to True
            ax.yaxis.set_ticks_position("right")
            ax.set_ylim(0, self.icoord_max)
            ax.set_xlim(0, root_y)
            ax.invert_xaxis()  # root on the left, leaves on the right
            ax.invert_yaxis()  # top -> bottom, consistent with heatmap pcolormesh
        else:  # vertical: leaves at y=0, root above
            ax.set_xlim(0, self.icoord_max)
            ax.set_ylim(0, root_y)

        despine(ax=ax, bottom=True, left=True)
        ax.set(
            xticks=self.xticks,
            yticks=self.yticks,
            xlabel=self.xlabel,
            ylabel=self.ylabel,
        )
        xtl = ax.set_xticklabels(self.xticklabels)
        ytl = ax.set_yticklabels(self.yticklabels, rotation="vertical")
        if len(ytl) > 0 and axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")
        if len(xtl) > 0 and axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        self.ax = ax
        return self
[docs]class ClusterMapPlotter: """ Clustermap (Heatmap) plotter. Plot heatmap / clustermap with annotation and legends. Parameters ---------- data : dataframe pandas dataframe or numpy array. z_score : int whether to perform z score scale, either 0 for rows or 1 for columns, after scale, value range would be from -1 to 1. standard_scale : int either 0 for rows or 1 for columns, after scale,value range would be from 0 to 1. top_annotation : annotation: class of HeatmapAnnotation. bottom_annotation : class AnnotationBase the same as top_annotation. left_annotation :class AnnotationBase the same as top_annotation. right_annotation :class AnnotationBase the same as top_annotation. row_cluster :bool whether to perform cluster on rows/columns. col_cluster :bool whether to perform cluster on rows/columns. row_cluster_method :str cluster method for row/columns linkage, such single, complete, average,weighted, centroid, median, ward. see scipy.cluster.hierarchy.linkage or ( for detail. row_cluster_metric : str Pairwise distances between observations in n-dimensional space for row/columns, such euclidean, minkowski, cityblock, seuclidean, cosine, correlation, hamming, jaccard, chebyshev, canberra, braycurtis, mahalanobis, kulsinski centroid, median, ward. see scipy.cluster.hierarchy.linkage or Please try metric='canberra' if there are two columns having the same values. col_cluster_method :str same as row_cluster_method col_cluster_metric :str same as row_cluster_metric show_rownames :bool True (default) or False, whether to show row ticklabels. show_colnames : bool True of False, same as show_rownames. row_names_side :str right or left. col_names_side :str top or bottom. row_dendrogram :bool True or False, whether to show dendrogram. col_dendrogram :bool True or False, whether to show dendrogram. row_dendrogram_size :int default is 10mm. col_dendrogram_size :int default is 10mm. 
row_split :int or pd.Series or pd.DataFrame number of cluster for hierarchical clustering or pd.Series or pd.DataFrame, used to split rows or rows into subplots. col_split :int or pd.Series or pd.DataFrame int or pd.Series or pd.DataFrame, used to split rows or columns into subplots. row_dendrogram_kws :dict custom linkage could be passed to row_dendrogram_kws, for example: row_dendrogram_kws=dict(linkage=my_linkage); Other kws passed to hierarchy.dendrogram. col_dendrogram_kws: dict custom linkage could be passed to col_dendrogram_kws, for example: col_dendrogram_kws=dict(linkage=my_linkage); Other kws passed to hierarchy.dendrogram. tree_kws :dict kws passed to DendrogramPlotter.plot() row_split_order: list or str a list to specify the order of row_split, could also be 'cluster_between_groups', if cluster_between_groups was specified, hierarchical clustering will be performed on the mean values for each groups and pass the clsutered order to row_split_order. For example, see col_split_order: list or str a list to specify the order of col_split, could also be 'cluster_between_groups', if cluster_between_groups was specified, hierarchical clustering will be performed on the mean values for each groups and pass the clsutered order to row_split_order. row_split_gap :float default are 0.5 and 0.2 mm for row and col. col_split_gap :float default are 0.5 and 0.2 mm for row and col. mask :dataframe or array mask the data in heatmap, the cell with missing values of infinite values will be masked automatically. subplot_gap :float the gap between subplots, default is 1mm. legend :bool True or False, whether to plot heatmap legend, determined by cmap. 
legend_kws :dict vmax, vmin and other kws passed to plot legend, such asfontsize, fontsize, labelcolor, numpoints, markerscale, markerfirst, frameon shadow, facecolor, edgecolor, title, title_fontsize, labelspacing and so on (see ?plt.legend) Alaternatively, we can also change the outline color and linewidth of cbar after plotting: cm=ClusterMapPlotter(...) for cbar in cm.cbars: if isinstance(cbar,matplotlib.colorbar.Colorbar): cbar.outline.set_color('white') cbar.outline.set_linewidth(2) cbar.dividers.set_color('red') cbar.dividers.set_linewidth(2) plot :bool whether to plot or not. plot_legend :bool True or False, whether to plot legend, if False, legends can be plot with ClusterMapPlotter.plot_legends() legend_anchor :str ax_heatmap or ax, the ax to which legend anchor. legend_gap :float the columns gap between different legends. legend_width: float [mm] width of the legend, default is None (infer from data automatically) legend_hpad: float Horizonal space between the heatmap and legend, default is 2 [mm]. legend_vpad: float Vertical space between the top of legend_anchor and legend, default is 5 [mm]. legend_side :str right of left. cmap :str default is 'jet', the colormap for heatmap colorbar, see plt.colormaps(). label :str the title (label) that will be shown in heatmap colorbar legend. xticklabels_kws :dict xticklabels or yticklabels kws, such as axis, which, direction, length, width, color, pad, labelsize, labelcolor, colors, zorder, bottom, top, left, right, labelbottom, labeltop, labelleft, labelright, labelrotation, grid_color, grid_linestyle and so on. For more information,see ?matplotlib.axes.Axes.tick_params or ?ax.tick_params. yticklabels_kws :dict the same as xticklabels_kws. xlabel: str default is None (no xlabel would be shown). ylabel: str default is None (no ylabel would be shown). 
xlabel_kws: dict alpha,color,fontfamily,fontname,fontproperties,fontsize,fontstyle, fontweight,label,rasterized,rotation,rotation_mode(default,anchor),visible, zorder,verticalalignment,horizontalalignment. See, for example: cm=ClusterMapPlotter(***), print( or for detail. ylabel_kws: dict sams as xlabel_kws xlabel_side: str bottom or top, default is bottom, ylabel_side: str left or right, default is left xlabel_bbox_kws: dict alpha,clip_box, clip_on,edgecolor,facecolor,fill,height,in_layout,label, linestyle, linewidth,rasterized,visible,width. See ax.xaxis.label.get_bbox_patch().properties() for more information. For example: cm=ClusterMapPlotter(***), print( ylabel_bbox_kws: dict same as xlabel_bbox_kws rasterized :bool default is auto, when the number of rows or number of cols > 5000, rasterized would be automatically set to True to speed up the plotting. kwargs :kws passed to plot_heatmap, including vmin, vmax,center,robust, annot, annot_kws, fmt, mask, linewidths linecolor, na_col, cbar,cbar_kwss ,xticklabels/yticklabels and so on (see ?PyComplexHeatmap.clustermap.plot_heatmap). If annot is True, the values of data will be plotted on the top of heatmap, if annot is a dataframe, then the custom values will be plotted on heatmap, fmt should be set to None if dtype of annot is str. For documentation of custom annot, see xticklabels/yticklabels will be shown automatically, if the width/height is too small to display all xticklabels, not all ticklabels will be shown (to avoid overlap). To force display all ticklabels, set xticklabels/yticklabels to True. Returns ------- Class ClusterMapPlotter. 
""" def __init__( self, data, z_score=None, standard_scale=None, top_annotation=None, bottom_annotation=None, left_annotation=None, right_annotation=None, row_cluster=True, col_cluster=True, row_cluster_method="average", row_cluster_metric="correlation", col_cluster_method="average", col_cluster_metric="correlation", show_rownames=False, show_colnames=False, row_names_side="right", col_names_side="bottom", xticklabels_kws=None, yticklabels_kws=None, row_dendrogram=False, col_dendrogram=False, row_dendrogram_size=10, col_dendrogram_size=10, row_split=None, col_split=None, row_dendrogram_kws=None, col_dendrogram_kws=None, tree_kws=None, row_split_order=None, col_split_order=None, row_split_gap=0.5, col_split_gap=0.2, mask=None, subplot_gap=1, legend=True, legend_kws=None, plot=True, plot_legend=True, legend_anchor="auto", legend_gap=7, legend_width=None, legend_hpad=1, legend_vpad=5, legend_side="right", cmap="jet", label=None, xlabel=None, ylabel=None, xlabel_kws=None, ylabel_kws=None, xlabel_side='bottom', ylabel_side='left', xlabel_bbox_kws=None, ylabel_bbox_kws=None, rasterized='auto', legend_delta_x=None, verbose=1, **kwargs ): self.kwargs = kwargs if not kwargs is None else {} self.rasterized = rasterized self.data2d = self.format_data(data, mask, z_score, standard_scale) self.verbose = verbose self._define_kws(xticklabels_kws, yticklabels_kws) self.top_annotation = top_annotation self.bottom_annotation = bottom_annotation self.left_annotation = left_annotation self.right_annotation = right_annotation self.row_dendrogram_size = row_dendrogram_size self.col_dendrogram_size = col_dendrogram_size self.row_cluster = row_cluster self.col_cluster = col_cluster self.row_cluster_method = row_cluster_method self.row_cluster_metric = row_cluster_metric self.col_cluster_method = col_cluster_method self.col_cluster_metric = col_cluster_metric self.show_rownames = show_rownames self.show_colnames = show_colnames self.row_names_side = row_names_side self.col_names_side = 
col_names_side self.row_dendrogram = row_dendrogram self.col_dendrogram = col_dendrogram self.subplot_gap = subplot_gap self.row_dendrogram_kws = {} if row_dendrogram_kws is None else row_dendrogram_kws self.col_dendrogram_kws = {} if col_dendrogram_kws is None else col_dendrogram_kws self.tree_kws = {} if tree_kws is None else tree_kws self.row_split = row_split self.col_split = col_split self.row_split_gap = row_split_gap self.col_split_gap = col_split_gap self.row_split_order = row_split_order self.col_split_order = col_split_order self.legend = legend self.legend_kws = legend_kws if not legend_kws is None else {} self.legend_side = legend_side self.cmap = cmap self.label = label if not label is None else "heatmap" self.legend_gap = legend_gap self.legend_width = legend_width self.legend_hpad = legend_hpad self.legend_vpad = legend_vpad self.legend_anchor = legend_anchor self.legend_delta_x = legend_delta_x self.xlabel = xlabel self.ylabel = ylabel self.xlabel_kws = xlabel_kws if not xlabel_kws is None else {} self.ylabel_kws = ylabel_kws if not ylabel_kws is None else {} self.xlabel_side = xlabel_side self.ylabel_side = ylabel_side self.xlabel_bbox_kws = xlabel_bbox_kws self.ylabel_bbox_kws = ylabel_bbox_kws if plot: self.plot() if plot_legend: if legend_anchor == "auto": if ( not self.right_annotation is None and self.legend_side == "right" ): legend_anchor = "ax" else: legend_anchor = "ax_heatmap" if legend_anchor == "ax_heatmap": self.plot_legends(ax=self.ax_heatmap) else: self.plot_legends( self.post_processing() def _define_kws(self, xticklabels_kws, yticklabels_kws): self.yticklabels_kws = {} if yticklabels_kws is None else yticklabels_kws # self.yticklabels_kws.setdefault('labelrotation', 0) self.xticklabels_kws = {} if xticklabels_kws is None else xticklabels_kws # self.xticklabels_kws.setdefault('labelrotation', 90)
def format_data(self, data, mask=None, z_score=None, standard_scale=None):
    """Return the matrix that will be plotted, optionally normalized.

    Parameters
    ----------
    data : pandas.DataFrame
        Input matrix.
    mask : optional
        Passed to ``_check_mask``; the result is stored in ``self.mask``.
    z_score, standard_scale : int or None
        Axis for z-scoring or min-max scaling; at most one may be given.

    Returns
    -------
    pandas.DataFrame
        The (normalized) copy of ``data``.

    Raises
    ------
    ValueError
        If both ``z_score`` and ``standard_scale`` are given.

    Side effects: fills default ``vmin``/``vmax``/``xticklabels``/
    ``yticklabels`` entries into ``self.kwargs`` and resolves
    ``self.rasterized == 'auto'`` to a bool.
    """
    result = data.copy()
    wants_both = z_score is not None and standard_scale is not None
    if wants_both:
        raise ValueError(
            "Cannot perform both z-scoring and standard-scaling on data"
        )
    if z_score is not None:
        result = self.z_score(data, z_score)
    elif standard_scale is not None:
        result = self.standard_scale(data, standard_scale)

    # Only fill in values the caller did not pass explicitly.
    fallback_kws = {
        "vmin": np.nanmin(result),
        "vmax": np.nanmax(result),
        "xticklabels": "auto",
        "yticklabels": "auto",
    }
    for key, value in fallback_kws.items():
        self.kwargs.setdefault(key, value)

    self.mask = _check_mask(result, mask)
    if self.rasterized == 'auto':
        # Large matrices are rasterized to keep rendering fast.
        n_rows, n_cols = result.shape[0], result.shape[1]
        self.rasterized = max(n_rows, n_cols) > 5000
    return result
def _define_gs_ratio(self):
    """Compute the pixel sizes of the 3x3 layout around the heatmap.

    Fills ``self.top_heights``, ``self.bottom_heights``, ``self.left_widths``
    and ``self.right_widths`` (lists of pixel sizes), then the summed
    ``self.heights`` / ``self.widths`` used as grid-spec ratios by
    ``_define_axes``. Dendrogram and annotation sizes are given in mm.
    """
    dpi = self.ax.figure.dpi

    def _annotation_px(annotation):
        # Missing annotations take no space; present ones use the sum of their
        # track sizes (mm), converted to pixels.
        if annotation is None:
            return 0
        return sum(annotation.heights) * mm2inch * dpi

    self.top_heights = []
    self.bottom_heights = []
    self.left_widths = []
    self.right_widths = []
    if self.col_dendrogram:
        self.top_heights.append(self.col_dendrogram_size * mm2inch * dpi)
    if self.row_dendrogram:
        self.left_widths.append(self.row_dendrogram_size * mm2inch * dpi)
    self.top_heights.append(_annotation_px(self.top_annotation))
    self.left_widths.append(_annotation_px(self.left_annotation))
    self.bottom_heights.append(_annotation_px(self.bottom_annotation))
    self.right_widths.append(_annotation_px(self.right_annotation))

    # The heatmap gets whatever remains of the outer axes.
    box = self.ax.get_window_extent()
    heatmap_h = box.height - sum(self.top_heights) - sum(self.bottom_heights)
    heatmap_w = box.width - sum(self.left_widths) - sum(self.right_widths)
    self.heights = [sum(self.top_heights), heatmap_h, sum(self.bottom_heights)]
    self.widths = [sum(self.left_widths), heatmap_w, sum(self.right_widths)]

def _define_axes(self, subplot_spec=None):
    """Create the 3x3 grid spec, the heatmap axes and its four neighbours.

    ``subplot_spec`` (optional) embeds the grid inside an existing subplot
    spec; otherwise a new grid spec is added to ``self.ax.figure``.
    """
    fig = self.ax.figure
    box = self.ax.get_window_extent()
    gap_px = self.subplot_gap * mm2inch * fig.dpi
    # GridSpec spacing is a fraction of the average cell size (extent / 3).
    wspace = gap_px / (box.width / 3)
    hspace = gap_px / (box.height / 3)
    grid_kws = dict(
        width_ratios=self.widths,
        height_ratios=self.heights,
        wspace=wspace,
        hspace=hspace,
    )
    if subplot_spec is None:
        self.gs = fig.add_gridspec(3, 3, **grid_kws)
    else:
        self.gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
            3, 3, subplot_spec=subplot_spec, **grid_kws
        )
    # left -> right, top -> bottom
    self.ax_heatmap = fig.add_subplot(self.gs[1, 1])
    self.ax_top = fig.add_subplot(self.gs[0, 1], sharex=self.ax_heatmap)
    self.ax_bottom = fig.add_subplot(self.gs[2, 1], sharex=self.ax_heatmap)
    self.ax_left = fig.add_subplot(self.gs[1, 0], sharey=self.ax_heatmap)
    self.ax_right = fig.add_subplot(self.gs[1, 2], sharey=self.ax_heatmap)
    self.ax_heatmap.set_xlim([0, self.data2d.shape[1]])
    self.ax_heatmap.set_ylim([0, self.data2d.shape[0]])
    self.ax_heatmap.xaxis.set_visible(False)
    no_ticks = dict(
        axis="both",
        which="both",
        left=False,
        right=False,
        top=False,
        bottom=False,
        labelleft=False,
        labelright=False,
        labeltop=False,
        labelbottom=False,
    )
    # The outer axes is not switched off (set_axis_off) so that the x/y labels
    # can still be drawn on it; only its ticks and spines are hidden.
    self.ax.tick_params(**no_ticks)
    self.ax_heatmap.tick_params(**no_ticks)
    despine(ax=self.ax, bottom=True, left=True, top=True, right=True)

def _define_top_axes(self):
    """Split the top slot into column-dendrogram and top-annotation axes."""
    self.top_gs = None
    if self.top_annotation is None:
        self.ax_top_annotation = None
        self.ax_col_dendrogram = self.ax_top if self.col_dendrogram else None
    elif self.col_dendrogram:
        # 2 rows x 1 column: column dendrogram (top) + top annotation (bottom).
        self.top_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
            2,
            1,
            hspace=0,
            wspace=0,
            subplot_spec=self.gs[0, 1],
            height_ratios=[
                self.col_dendrogram_size,
                sum(self.top_annotation.heights),
            ],
        )
        self.ax_top_annotation = self.ax_top.figure.add_subplot(self.top_gs[1, 0])
        self.ax_col_dendrogram = self.ax_top.figure.add_subplot(self.top_gs[0, 0])
        # Same treatment as the row dendrogram axes in _define_left_axes.
        self.ax_col_dendrogram.set_axis_off()
    else:
        self.ax_top_annotation = self.ax_top
        self.ax_col_dendrogram = None
    self.ax_top.set_axis_off()

def _define_left_axes(self):
    """Split the left slot into row-dendrogram and left-annotation axes."""
    self.left_gs = None
    if self.left_annotation is None:
        self.ax_left_annotation = None
        self.ax_row_dendrogram = self.ax_left if self.row_dendrogram else None
    elif self.row_dendrogram:
        # 1 row x 2 columns: row dendrogram (left) + left annotation (right).
        self.left_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
            1,
            2,
            hspace=0,
            wspace=0,
            subplot_spec=self.gs[1, 0],
            width_ratios=[
                self.row_dendrogram_size,
                sum(self.left_annotation.heights),
            ],
        )
        self.ax_left_annotation = self.ax_left.figure.add_subplot(self.left_gs[0, 1])
        self.ax_row_dendrogram = self.ax_left.figure.add_subplot(self.left_gs[0, 0])
        self.ax_row_dendrogram.set_axis_off()
    else:
        self.ax_left_annotation = self.ax_left
        self.ax_row_dendrogram = None
    self.ax_left.set_axis_off()

def _define_bottom_axes(self):
    """Use the bottom slot for the bottom annotation (if any)."""
    self.ax_bottom_annotation = None if self.bottom_annotation is None else self.ax_bottom
    self.ax_bottom.set_axis_off()

def _define_right_axes(self):
    """Use the right slot for the right annotation (if any)."""
    self.ax_right_annotation = None if self.right_annotation is None else self.ax_right
    self.ax_right.set_axis_off()
@staticmethod
def z_score(data2d, axis=1):
    """Standardize ``data2d`` to zero mean and unit standard deviation.

    Parameters
    ----------
    data2d : pandas.DataFrame
        Data to normalize.
    axis : int
        With 1 (the default) every column is standardized. Any other value
        transposes first, so 0 standardizes every row.

    Returns
    -------
    pandas.DataFrame
        Standardized data with the same orientation as ``data2d``. The std
        is pandas' default sample std (ddof=1).
    """
    transposed = axis != 1
    frame = data2d.T if transposed else data2d
    centered = frame - frame.mean()
    scaled = centered / frame.std()
    return scaled.T if transposed else scaled
@staticmethod
def standard_scale(data2d, axis=1):
    """Min-max scale ``data2d`` into the [0, 1] range.

    Parameters
    ----------
    data2d : pandas.DataFrame
        Data to normalize.
    axis : int
        With 1 (the default) every column is scaled. Any other value
        transposes first, so 0 scales every row.

    Returns
    -------
    pandas.DataFrame
        Each column (or row) minus its minimum, divided by its range
        (max - min). A constant column/row gives 0/0, i.e. NaN.
    """
    transposed = axis != 1
    frame = data2d.T if transposed else data2d
    lowest = frame.min()
    spread = frame.max() - lowest
    scaled = (frame - lowest) / spread
    return scaled.T if transposed else scaled
def calculate_row_dendrograms(self, data, sizes=None, use_linkage=True):
    """Cluster the rows of ``data`` and store the result in ``self.dendrogram_row``.

    Parameters
    ----------
    data : pandas.DataFrame
        Matrix whose rows are clustered.
    sizes : list, optional
        Forwarded to ``DendrogramPlotter``.
    use_linkage : bool
        When True, a ``'linkage'`` entry in ``self.row_dendrogram_kws`` is
        removed from the kws and passed as the precomputed linkage.
    """
    dendrogram_kws = self.row_dendrogram_kws.copy()
    precomputed = dendrogram_kws.pop('linkage', None) if use_linkage else None
    self.dendrogram_row = DendrogramPlotter(
        data,
        linkage=precomputed,
        axis=0,
        metric=self.row_cluster_metric,
        method=self.row_cluster_method,
        label=False,
        rotate=True,
        sizes=sizes,
        dendrogram_kws=dendrogram_kws,
    )
    if self.ax_row_dendrogram is not None:
        self.ax_row_dendrogram.set_axis_off()
def calculate_col_dendrograms(self, data, sizes=None, use_linkage=True):
    """Cluster the columns of ``data`` and store the result in ``self.dendrogram_col``.

    Parameters
    ----------
    data : pandas.DataFrame
        Matrix whose columns are clustered.
    sizes : list, optional
        Forwarded to ``DendrogramPlotter``.
    use_linkage : bool
        When True, a ``'linkage'`` entry in ``self.col_dendrogram_kws`` is
        removed from the kws and passed as the precomputed linkage.
    """
    dendrogram_kws = self.col_dendrogram_kws.copy()
    precomputed = dendrogram_kws.pop('linkage', None) if use_linkage else None
    self.dendrogram_col = DendrogramPlotter(
        data,
        linkage=precomputed,
        axis=1,
        metric=self.col_cluster_metric,
        method=self.col_cluster_method,
        label=False,
        rotate=False,
        sizes=sizes,
        dendrogram_kws=dendrogram_kws,
    )
    if self.ax_col_dendrogram is not None:
        self.ax_col_dendrogram.set_axis_off()
def cal_rowd_between_groups(self, row_clusters):
    """Cluster row groups by their mean profiles (dendrogram between groups).

    Parameters
    ----------
    row_clusters : pandas.Series
        Maps each group id to the list of row labels in that group.

    Each group is collapsed to the column-wise mean of its rows. The
    resulting (n_groups x n_columns) matrix is clustered via
    ``calculate_row_dendrograms`` with the group sizes.
    """
    mat = pd.concat(
        [self.data2d.loc[rows].mean() for rows in row_clusters.tolist()], axis=1
    ).T  # columns are the original columns
    mat.index = row_clusters.index.tolist()
    sizes = row_clusters.map(len).tolist()
    # A user-supplied linkage describes all rows, not these group means, so it
    # must not be reused here.
    self.calculate_row_dendrograms(mat, sizes=sizes, use_linkage=False)
def cal_cold_between_groups(self, col_clusters):
    """Cluster column groups by their mean profiles (dendrogram between groups).

    Parameters
    ----------
    col_clusters : pandas.Series
        Maps each group id to the list of column labels in that group.

    Each group is collapsed to the row-wise mean of its columns. The
    resulting (n_rows x n_groups) matrix is clustered via
    ``calculate_col_dendrograms`` with the group sizes.
    """
    mat = pd.concat(
        [self.data2d.loc[:, cols].mean(axis=1) for cols in col_clusters.tolist()],
        axis=1,
    )  # index holds the original row labels
    mat.columns = col_clusters.index.tolist()
    sizes = col_clusters.map(len).tolist()
    # A user-supplied linkage describes all columns, not these group means,
    # so it must not be reused here.
    self.calculate_col_dendrograms(mat, sizes=sizes, use_linkage=False)
def _reorder_rows(self):
    """Compute ``self.row_order``: a list of row-label lists, one per row group.

    - No split + clustering: a single group ordered by the dendrogram.
    - Integer split + clustering: cut the full tree into ``row_split``
      clusters, order the clusters by a between-groups dendrogram and
      cluster within each group.
    - Series/DataFrame split: group rows by its values, order the groups by
      ``row_split_order`` (or by a between-groups dendrogram when it is None
      or ``'cluster_between_groups'``), then cluster within each group if
      ``row_cluster``.
    - Otherwise, without clustering: all rows in their original order.

    Raises
    ------
    TypeError
        If ``row_split`` has an unsupported type while clustering is on.
    """
    self.row_split_dendrogram = False
    if self.verbose >= 1:
        print("Reordering rows..")
    if self.row_split is None and self.row_cluster:
        self.calculate_row_dendrograms(self.data2d)
        self.row_order = [self.dendrogram_row.dendrogram["ivl"]]
        return None
    # The order may be a list/ndarray; comparing an ndarray to a string with
    # ``==`` is elementwise and cannot be used in a boolean context.
    order = self.row_split_order
    between_groups = order is None or (
        isinstance(order, str) and order == 'cluster_between_groups'
    )
    if isinstance(self.row_split, (int, np.integer)) and self.row_cluster:
        self.calculate_row_dendrograms(self.data2d)  # sets self.dendrogram_row
        row_clusters = (
            pd.Series(
                hierarchy.fcluster(
                    self.dendrogram_row.linkage,
                    t=self.row_split,
                    criterion="maxclust",
                ),
                index=self.data2d.index.tolist(),
            )
            .to_frame(name="cluster")
            .groupby("cluster")
            .apply(lambda x: x.index.tolist())
        )
        self.cal_rowd_between_groups(row_clusters)
        row_split_order = self.dendrogram_row.dendrogram["ivl"]
        self.row_split_dendrogram = self.dendrogram_row
        self.row_clusters = row_clusters.loc[row_split_order].to_dict()
    elif isinstance(self.row_split, (pd.Series, pd.DataFrame)):
        if isinstance(self.row_split, pd.Series):
            self.row_split = self.row_split.to_frame()
        cols = self.row_split.columns.tolist()
        row_clusters = self.row_split.groupby(cols).apply(
            lambda x: x.index.tolist()
        )
        if between_groups:
            self.cal_rowd_between_groups(row_clusters)
            self.row_split_dendrogram = self.dendrogram_row  # between groups
            row_split_order = self.dendrogram_row.dendrogram["ivl"]
        else:
            row_split_order = order
        self.row_clusters = row_clusters.loc[row_split_order].to_dict()
    elif not self.row_cluster:
        self.row_order = [self.data2d.index.tolist()]
        return None
    else:
        raise TypeError("row_split must be integer or dataframe or series")

    self.row_order = []
    self.dendrogram_rows = []  # one DendrogramPlotter (or None) per row group
    explicit_order = isinstance(order, (list, np.ndarray))
    for rows in self.row_clusters.values():
        if len(rows) <= 1 and explicit_order:
            # A single row cannot be clustered.
            self.row_order.append(rows)
            self.dendrogram_rows.append(None)
            continue
        if self.row_cluster:
            # Cluster within the group; a user linkage covers all rows, so it
            # does not apply to this subset.
            self.calculate_row_dendrograms(self.data2d.loc[rows], use_linkage=False)
            self.dendrogram_rows.append(self.dendrogram_row)
            self.row_order.append(self.dendrogram_row.dendrogram["ivl"])
        else:
            self.row_order.append(rows)

def _reorder_cols(self):
    """Compute ``self.col_order``: a list of column-label lists, one per group.

    Column counterpart of ``_reorder_rows``; see its documentation.

    Raises
    ------
    TypeError
        If ``col_split`` has an unsupported type while clustering is on.
    """
    self.col_split_dendrogram = False
    if self.verbose >= 1:
        print("Reordering cols..")
    if self.col_split is None and self.col_cluster:
        self.calculate_col_dendrograms(self.data2d)
        self.col_order = [self.dendrogram_col.dendrogram["ivl"]]
        return None
    order = self.col_split_order
    between_groups = order is None or (
        isinstance(order, str) and order == 'cluster_between_groups'
    )
    if isinstance(self.col_split, (int, np.integer)) and self.col_cluster:
        self.calculate_col_dendrograms(self.data2d)
        col_clusters = (
            pd.Series(
                hierarchy.fcluster(
                    self.dendrogram_col.linkage,
                    t=self.col_split,
                    criterion="maxclust",
                ),
                index=self.data2d.columns.tolist(),
            )
            .to_frame(name="cluster")
            .groupby("cluster")
            .apply(lambda x: x.index.tolist())
        )
        self.cal_cold_between_groups(col_clusters)
        col_split_order = self.dendrogram_col.dendrogram["ivl"]
        self.col_split_dendrogram = self.dendrogram_col
        self.col_clusters = col_clusters.loc[col_split_order].to_dict()
    elif isinstance(self.col_split, (pd.Series, pd.DataFrame)):
        if isinstance(self.col_split, pd.Series):
            self.col_split = self.col_split.to_frame()
        cols = self.col_split.columns.tolist()
        col_clusters = self.col_split.groupby(cols).apply(
            lambda x: x.index.tolist()
        )
        if between_groups:
            self.cal_cold_between_groups(col_clusters)
            self.col_split_dendrogram = self.dendrogram_col
            col_split_order = self.dendrogram_col.dendrogram["ivl"]
        else:
            col_split_order = order
        self.col_clusters = col_clusters.loc[col_split_order].to_dict()
    elif not self.col_cluster:
        self.col_order = [self.data2d.columns.tolist()]
        return None
    else:
        raise TypeError("col_split must be integer or dataframe or series")

    self.col_order = []
    self.dendrogram_cols = []  # one DendrogramPlotter (or None) per col group
    explicit_order = isinstance(order, (list, np.ndarray))
    for cols in self.col_clusters.values():
        if len(cols) <= 1 and explicit_order:
            # A single column cannot be clustered.
            self.col_order.append(cols)
            self.dendrogram_cols.append(None)
            continue
        if self.col_cluster:
            self.calculate_col_dendrograms(self.data2d.loc[:, cols], use_linkage=False)
            self.dendrogram_cols.append(self.dendrogram_col)
            self.col_order.append(self.dendrogram_col.dendrogram["ivl"])
        else:
            self.col_order.append(cols)
def plot_dendrograms(self, row_order, col_order):
    """Draw the row and column dendrograms next to the heatmap.

    Parameters
    ----------
    row_order, col_order : list of lists
        Labels of each row / column group, as produced by
        ``_reorder_rows`` / ``_reorder_cols``.

    ``self.tree_kws`` may contain ``row_cmap`` / ``col_cmap``; when given,
    each group's dendrogram is colored from that colormap.
    """
    # Work on a copy so repeated calls still see the cmap entries and the
    # stored tree_kws dict is never modified.
    tree_kws = dict(self.tree_kws)
    rcmap = tree_kws.pop("row_cmap", None)
    ccmap = tree_kws.pop("col_cmap", None)
    self._plot_row_dendrograms(row_order, tree_kws, rcmap)
    self._plot_col_dendrograms(col_order, tree_kws, ccmap)

def _plot_row_dendrograms(self, row_order, base_tree_kws, rcmap):
    """Plot row dendrogram(s) into the left slot (helper of plot_dendrograms)."""
    order = self.row_split_order
    # ``order`` may be a list/ndarray; only a string can select this mode.
    between_groups = isinstance(order, str) and order == 'cluster_between_groups'
    if between_groups and self.row_split_dendrogram and self.row_dendrogram:
        self.row_split_dendrogram.plot(
            ax=self.ax_row_dendrogram,
            gap_pixel=self.row_split_gap_pixel,
            tree_kws=dict(base_tree_kws),
        )
        return
    if not (self.row_cluster and self.row_dendrogram):
        return
    gs = self.gs[1, 0] if self.left_annotation is None else self.left_gs[0, 0]
    # An extra (left) column holds the between-groups dendrogram.
    ncols = 2 if len(row_order) > 1 and self.row_split_dendrogram else 1
    self.row_dendrogram_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
        len(row_order),
        ncols,
        hspace=self.hspace,
        wspace=0,
        subplot_spec=gs,
        height_ratios=[len(rows) for rows in row_order],
    )
    self.ax_row_dendrogram_axes = []
    for i in range(len(row_order)):
        ax1 = self.ax_row_dendrogram.figure.add_subplot(self.row_dendrogram_gs[i, -1])
        ax1.set_axis_off()
        self.ax_row_dendrogram_axes.append(ax1)

    tree_kws = dict(base_tree_kws)
    default_color = base_tree_kws.get('colors', 'black')
    # ``dendrogram_rows`` only exists when rows were split into groups.
    dendrogram_rows = getattr(self, "dendrogram_rows", None)
    if dendrogram_rows is None:
        self.dendrogram_row.plot(ax=self.ax_row_dendrogram, tree_kws=tree_kws)
    else:
        if rcmap is None:
            colors = [default_color] * len(dendrogram_rows)
        else:
            colormap = get_colormap(rcmap)
            colors = [colormap(i) for i in range(len(dendrogram_rows))]
        for ax_group, dendrogram_row, color in zip(
            self.ax_row_dendrogram_axes, dendrogram_rows, colors
        ):
            if dendrogram_row is None:  # single-row group, nothing to draw
                continue
            tree_kws["colors"] = [color] * len(dendrogram_row.dendrogram["ivl"])
            dendrogram_row.plot(ax=ax_group, tree_kws=tree_kws)
    if ncols > 1 and self.row_split_dendrogram:
        # Parent dendrogram joining the roots of the per-group dendrograms.
        tree_kws['colors'] = default_color
        root_x = [dendrogram_row.root_x for dendrogram_row in dendrogram_rows]
        self.ax_row_split_dendrogram = self.ax_row_dendrogram.figure.add_subplot(
            self.row_dendrogram_gs[:, 0]
        )
        self.ax_row_split_dendrogram.set_axis_off()
        self.row_split_dendrogram.plot(
            ax=self.ax_row_split_dendrogram,
            gap_pixel=self.row_split_gap_pixel,
            root_x=root_x,
            tree_kws=tree_kws,
        )

def _plot_col_dendrograms(self, col_order, base_tree_kws, ccmap):
    """Plot column dendrogram(s) into the top slot (helper of plot_dendrograms)."""
    order = self.col_split_order
    between_groups = isinstance(order, str) and order == 'cluster_between_groups'
    if between_groups and self.col_split_dendrogram and self.col_dendrogram:
        self.col_split_dendrogram.plot(
            ax=self.ax_col_dendrogram,
            gap_pixel=self.col_split_gap_pixel,
            tree_kws=dict(base_tree_kws),
        )
        return
    if not (self.col_cluster and self.col_dendrogram):
        return
    gs = self.gs[0, 1] if self.top_annotation is None else self.top_gs[0, 0]
    # An extra (top) row holds the between-groups dendrogram.
    nrows = 2 if len(col_order) > 1 and self.col_split_dendrogram else 1
    self.col_dendrogram_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
        nrows,
        len(col_order),
        hspace=0,
        wspace=self.wspace,
        subplot_spec=gs,
        width_ratios=[len(cols) for cols in col_order],
    )
    self.ax_col_dendrogram_axes = []
    for i in range(len(col_order)):
        ax1 = self.ax_col_dendrogram.figure.add_subplot(self.col_dendrogram_gs[-1, i])
        ax1.set_axis_off()
        self.ax_col_dendrogram_axes.append(ax1)

    tree_kws = dict(base_tree_kws)
    default_color = base_tree_kws.get('colors', 'black')
    # ``dendrogram_cols`` only exists when columns were split into groups.
    dendrogram_cols = getattr(self, "dendrogram_cols", None)
    if dendrogram_cols is None:
        self.dendrogram_col.plot(ax=self.ax_col_dendrogram, tree_kws=tree_kws)
    else:
        if ccmap is None:
            colors = [default_color] * len(dendrogram_cols)
        else:
            colormap = get_colormap(ccmap)
            colors = [colormap(i) for i in range(len(dendrogram_cols))]
        for ax_group, dendrogram_col, color in zip(
            self.ax_col_dendrogram_axes, dendrogram_cols, colors
        ):
            if dendrogram_col is None:  # single-column group, nothing to draw
                continue
            tree_kws["colors"] = [color] * len(dendrogram_col.dendrogram["ivl"])
            dendrogram_col.plot(ax=ax_group, tree_kws=tree_kws)
    if nrows > 1 and self.col_split_dendrogram:
        # Parent dendrogram joining the roots of the per-group dendrograms.
        tree_kws['colors'] = default_color
        root_x = [dendrogram_col.root_x for dendrogram_col in dendrogram_cols]
        self.ax_col_split_dendrogram = self.ax_col_dendrogram.figure.add_subplot(
            self.col_dendrogram_gs[0, :]
        )
        self.ax_col_split_dendrogram.set_axis_off()
        self.col_split_dendrogram.plot(
            ax=self.ax_col_split_dendrogram,
            gap_pixel=self.col_split_gap_pixel,
            root_x=root_x,
            tree_kws=tree_kws,
        )
def plot_matrix(self, row_order, col_order):
    """Draw the heatmap as a grid of sub-axes, one per (row group, col group).

    Parameters
    ----------
    row_order, col_order : list of lists
        Labels of each row / column group.

    Sets ``self.col_split_gap_pixel``, ``self.row_split_gap_pixel``,
    ``self.wspace``, ``self.hspace``, ``self.heatmap_gs`` and
    ``self.heatmap_axes`` (an object ndarray of shape (n_row_groups,
    n_col_groups)).

    Raises
    ------
    ValueError
        If a custom ``annot`` does not have the same shape as the data.
    """
    if self.verbose >= 1:
        print("Plotting matrix..")
    nrows = len(row_order)
    ncols = len(col_order)
    dpi = self.ax.figure.dpi
    heatmap_box = self.ax_heatmap.get_window_extent()
    # Split gaps are configured in mm; mm * mm2inch * dpi gives pixels.
    self.col_split_gap_pixel = self.col_split_gap * mm2inch * dpi
    self.row_split_gap_pixel = self.row_split_gap * mm2inch * dpi
    # GridSpec wspace/hspace are fractions of the *average* cell size. With n
    # cells of size a and n-1 gaps of size g in extent W: a = (W-(n-1)g)/n,
    # so space = g/a = g*n / (W + g - g*n).
    self.wspace = (self.col_split_gap_pixel * ncols) / (
        heatmap_box.width + self.col_split_gap_pixel - self.col_split_gap_pixel * ncols
    )
    self.hspace = (self.row_split_gap_pixel * nrows) / (
        heatmap_box.height + self.row_split_gap_pixel - self.row_split_gap_pixel * nrows
    )
    self.heatmap_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
        nrows,
        ncols,
        hspace=self.hspace,
        wspace=self.wspace,
        subplot_spec=self.gs[1, 1],
        height_ratios=[len(rows) for rows in row_order],
        width_ratios=[len(cols) for cols in col_order],
    )

    # ``annot`` is sliced per sub-block here instead of being forwarded as a
    # whole. It is read (not popped) so self.kwargs survives repeated calls.
    annot = self.kwargs.get("annot", None)
    heatmap_kws = {key: value for key, value in self.kwargs.items() if key != "annot"}
    if annot is None or annot is False:
        annot_data = None
    elif isinstance(annot, bool):
        annot_data = self.data2d
    else:
        annot_data = annot.copy()
        if annot_data.shape != self.data2d.shape:
            err = "`data` and `annot` must have same shape."
            raise ValueError(err)

    self.heatmap_axes = np.empty(shape=(nrows, ncols), dtype=object)
    self.ax_heatmap.set_axis_off()
    for i, rows in enumerate(row_order):
        for j, cols in enumerate(col_order):
            # Share x within a column group and y within a row group.
            ax1 = self.ax_heatmap.figure.add_subplot(
                self.heatmap_gs[i, j],
                sharex=self.heatmap_axes[0, j],
                sharey=self.heatmap_axes[i, 0],
            )
            # x spans this block's columns, y spans its rows.
            ax1.set_xlim([0, len(cols)])
            ax1.set_ylim([0, len(rows)])
            annot1 = None if annot_data is None else annot_data.loc[rows, cols]
            plot_heatmap(
                self.data2d.loc[rows, cols],
                ax=ax1,
                cmap=self.cmap,
                mask=self.mask.loc[rows, cols],
                rasterized=self.rasterized,
                annot=annot1,
                **heatmap_kws
            )
            self.heatmap_axes[i, j] = ax1
            ax1.yaxis.label.set_visible(False)
            ax1.xaxis.label.set_visible(False)
            ax1.tick_params(
                left=False,
                right=False,
                labelleft=False,
                labelright=False,
                top=False,
                bottom=False,
                labeltop=False,
                labelbottom=False,
            )
def set_axes_labels_kws(self):
    """Show row/column names on the outermost heatmap sub-axes.

    Row names go on the left when ``show_rownames`` is set, nothing occupies
    the left slot, and either a right annotation exists or
    ``row_names_side == "left"``. Otherwise they go on the right if there is
    no right annotation. Column names follow the same rule with top/bottom.
    The side actually used is written back to ``self.row_names_side`` /
    ``self.col_names_side`` (set_xy_labels relies on it). The shown label
    artists are collected in ``self.yticklabels`` / ``self.xticklabels``.
    """
    self.ax_heatmap.yaxis.set_tick_params(**self.yticklabels_kws)
    self.ax_heatmap.xaxis.set_tick_params(**self.xticklabels_kws)
    self.yticklabels = []
    self.xticklabels = []

    names_left = (
        self.show_rownames
        and self.left_annotation is None
        and not self.row_dendrogram
        and (self.right_annotation is not None or self.row_names_side == "left")
    )
    if names_left:
        self.row_names_side = "left"
        self.yticklabels_kws.setdefault("labelrotation", 0)
        for ax in self.heatmap_axes[:, 0]:
            ax.yaxis.set_visible(True)
            ax.tick_params(axis="y", which="both", left=False, labelleft=True)
            ax.yaxis.set_tick_params(**self.yticklabels_kws)
            plt.setp(ax.get_yticklabels(), rotation_mode="anchor", ha="right", va="center")
            self.yticklabels.extend(ax.get_yticklabels())
    elif self.show_rownames and self.right_annotation is None:
        self.row_names_side = "right"
        self.yticklabels_kws.setdefault("labelrotation", 0)
        for ax in self.heatmap_axes[:, -1]:
            ax.yaxis.tick_right()
            ax.yaxis.set_visible(True)
            ax.tick_params(axis="y", which="both", right=False, labelright=True)
            ax.yaxis.set_tick_params(**self.yticklabels_kws)
            plt.setp(ax.get_yticklabels(), rotation_mode="anchor", ha="left", va="center")
            self.yticklabels.extend(ax.get_yticklabels())

    names_top = (
        self.show_colnames
        and self.top_annotation is None
        and not self.col_dendrogram
        and (self.bottom_annotation is not None or self.col_names_side == "top")
    )
    if names_top:
        self.col_names_side = "top"
        self.xticklabels_kws.setdefault("labelrotation", 90)
        for ax in self.heatmap_axes[0, :]:
            ax.xaxis.tick_top()
            ax.xaxis.set_visible(True)
            ax.tick_params(axis="x", which="both", top=False, labeltop=True)
            ax.xaxis.set_tick_params(**self.xticklabels_kws)
            # With rotation=90, ha="left" anchors the text at its bottom end.
            plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="left", va="center")
            self.xticklabels.extend(ax.get_xticklabels())
    elif self.show_colnames and self.bottom_annotation is None:
        self.col_names_side = "bottom"
        self.xticklabels_kws.setdefault("labelrotation", -90)
        for ax in self.heatmap_axes[-1, :]:
            ax.xaxis.tick_bottom()
            ax.xaxis.set_visible(True)
            ax.tick_params(axis="x", which="both", bottom=False, labelbottom=True)
            ax.xaxis.set_tick_params(**self.xticklabels_kws)
            plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="left", va="center")
            self.xticklabels.extend(ax.get_xticklabels())
def set_xy_labels(self):
    """Draw the optional x/y axis labels on the outer axes ``self.ax``.

    Each label is centered on the heatmap (not the whole outer axes) unless
    ``position`` is given in the label kws. When tick labels sit between the
    heatmap and the label, the label is padded by their extent (converted
    from pixels to points) unless ``labelpad`` is given.
    """
    if self.xlabel is not None:
        if len(self.xticklabels) > 0 and self.xlabel_side == self.col_names_side == 'bottom':
            xticklabel_h = max(
                label.get_window_extent().height for label in self.xticklabels
            )  # pixels
        else:
            xticklabel_h = 0
        if 'labelpad' in self.xlabel_kws:
            labelpad = self.xlabel_kws.pop('labelpad')
        else:
            # pixels -> points (1 point == dpi / 72 pixels)
            labelpad = xticklabel_h * 72 / self.ax.figure.dpi
        if 'position' not in self.xlabel_kws:
            heatmap_box = self.ax_heatmap.get_window_extent()
            ax_box = self.ax.get_window_extent()
            pos_x = (np.mean([heatmap_box.x0, heatmap_box.x1]) - ax_box.x0) / ax_box.width
            self.xlabel_kws.setdefault('position', (pos_x, 0))
        self.xlabel_kws.setdefault("verticalalignment", "center")
        # NOTE(review): the term added to labelpad was lost in the extracted
        # source; matplotlib's default label spacing is assumed - confirm.
        self.ax.set_xlabel(
            self.xlabel,
            labelpad=labelpad + matplotlib.rcParams["axes.labelpad"],
            **self.xlabel_kws
        )
        # NOTE(review): assumed to be how xlabel_side is applied - confirm.
        self.ax.xaxis.set_label_position(self.xlabel_side)
        if self.xlabel_bbox_kws is not None:
            self.ax.xaxis.label.set_bbox(self.xlabel_bbox_kws)

    if self.ylabel is not None:
        if len(self.yticklabels) > 0 and self.row_names_side == self.ylabel_side:
            yticklabel_w = max(
                label.get_window_extent().width for label in self.yticklabels
            )  # pixels
        else:
            yticklabel_w = 0
        if 'labelpad' in self.ylabel_kws:
            labelpad = self.ylabel_kws.pop('labelpad')
        else:
            labelpad = yticklabel_w * 72 / self.ax.figure.dpi
        # Check the y-label kws (the x-label kws are irrelevant here).
        if 'position' not in self.ylabel_kws:
            heatmap_box = self.ax_heatmap.get_window_extent()
            ax_box = self.ax.get_window_extent()
            pos_y = (np.mean([heatmap_box.y0, heatmap_box.y1]) - ax_box.y0) / ax_box.height
            self.ylabel_kws.setdefault('position', (0, pos_y))
        self.ylabel_kws.setdefault("horizontalalignment", "center")
        self.ax.set_ylabel(
            self.ylabel,
            labelpad=labelpad + matplotlib.rcParams["axes.labelpad"],
            **self.ylabel_kws
        )
        self.ax.yaxis.set_label_position(self.ylabel_side)
        if self.ylabel_bbox_kws is not None:
            self.ax.yaxis.label.set_bbox(self.ylabel_bbox_kws)
def collect_legends(self):
    """Gather legend specs from the annotations and the heatmap colorbar.

    Fills ``self.legend_list`` (sorted by each entry's 4th field) and
    ``self.label_max_width``: the widest label, in pixels, that the
    legends must clear.
    """
    if self.verbose >= 1:
        print("Collecting legends..")
    self.legend_list = []
    self.label_max_width = 0
    for annotation in (
        self.top_annotation,
        self.bottom_annotation,
        self.left_annotation,
        self.right_annotation,
    ):
        if annotation is None:
            continue
        annotation.collect_legends()
        if annotation.plot_legend and len(annotation.legend_list) > 0:
            self.legend_list.extend(annotation.legend_list)
        self.label_max_width = max(self.label_max_width, annotation.label_max_width)
    if self.legend:
        # Default the colorbar range to the heatmap's vmin/vmax/center.
        self.legend_kws.setdefault("vmin", self.kwargs.get('vmin'))
        self.legend_kws.setdefault("vmax", self.kwargs.get('vmax'))
        self.legend_kws.setdefault("center", self.kwargs.get('center', None))
        self.legend_list.append([self.cmap, self.label, self.legend_kws, 4, "cmap"])

    if len(self.yticklabels) > 0 and self.row_names_side == "right":
        max_yticklabel_w = max(
            label.get_window_extent().width for label in self.yticklabels
        )
    else:
        max_yticklabel_w = 0
    if self.ylabel is not None and self.ylabel_side == 'right':
        ylabel_w = self.ax.yaxis.label.get_window_extent().width
    else:
        ylabel_w = 0
    if self.row_names_side == self.ylabel_side == 'right':
        # Row names and y label are placed side by side on the right.
        heatmap_label_max_width = max_yticklabel_w + ylabel_w
    else:
        heatmap_label_max_width = max(max_yticklabel_w, ylabel_w)
    if (
        heatmap_label_max_width >= self.label_max_width
        or self.legend_anchor == "ax_heatmap"
    ):
        self.label_max_width = heatmap_label_max_width
    if len(self.legend_list) > 1:
        self.legend_list = sorted(self.legend_list, key=lambda x: x[3])
def plot_legends(self, ax=None):
    """Draw every legend collected by ``collect_legends`` next to ``ax``.

    Sets ``self.legend_axes``, ``self.cbars`` and ``self.boundry`` from
    ``plot_legend_list``. Does nothing when there are no legends.
    """
    if self.verbose >= 1:
        print("Plotting legends..")
    if len(self.legend_list) == 0:
        return
    # Right-hand legends must clear a right annotation or right-side row names.
    needs_space = self.legend_side == "right" and (
        self.right_annotation is not None
        or (self.show_rownames and self.row_names_side == "right")
    )
    space = self.label_max_width if needs_space else 0
    legend_hpad = self.legend_hpad * mm2inch * self.ax.figure.dpi  # mm -> pixels
    self.legend_axes, self.cbars, self.boundry = plot_legend_list(
        self.legend_list,
        ax=ax,
        space=space + legend_hpad,
        legend_side=self.legend_side,
        gap=self.legend_gap,
        delta_x=self.legend_delta_x,
        legend_width=self.legend_width,
        legend_vpad=self.legend_vpad,
        verbose=self.verbose
    )
def plot(self, ax=None, subplot_spec=None, row_order=None, col_order=None):
    """Lay out all axes, compute row/column orders and draw everything.

    Parameters
    ----------
    ax : matplotlib Axes, optional
        Outer axes to draw into; defaults to ``plt.gca()``.
    subplot_spec : optional
        Embed the layout into an existing subplot spec.
    row_order, col_order : list of lists, optional
        Precomputed group orders; computed by ``_reorder_rows`` /
        ``_reorder_cols`` when omitted.
    """
    if self.verbose >= 1:
        print("Starting plotting..")
    self.ax = plt.gca() if ax is None else ax
    self._define_gs_ratio()
    self._define_axes(subplot_spec)
    self._define_top_axes()
    self._define_left_axes()
    self._define_bottom_axes()
    self._define_right_axes()
    if row_order is None:
        if self.verbose >= 1:
            print("Starting calculating row orders..")
        self._reorder_rows()
        row_order = self.row_order
    if col_order is None:
        if self.verbose >= 1:
            print("Starting calculating col orders..")
        self._reorder_cols()
        col_order = self.col_order
    self.plot_matrix(row_order=row_order, col_order=col_order)
    if self.top_annotation is not None:
        # With a column dendrogram the top slot is split (see _define_top_axes).
        gs = self.gs[0, 1] if not self.col_dendrogram else self.top_gs[1, 0]
        self.top_annotation._set_orentation("up")
        self.top_annotation.plot_annotations(
            ax=self.ax_top_annotation,
            subplot_spec=gs,
            idxs=col_order,
            wspace=self.wspace,
        )
    if self.bottom_annotation is not None:
        self.bottom_annotation._set_orentation("down")
        self.bottom_annotation.plot_annotations(
            ax=self.ax_bottom_annotation,
            subplot_spec=self.gs[2, 1],
            idxs=col_order,
            wspace=self.wspace,
        )
    if self.left_annotation is not None:
        # With a row dendrogram the left slot is split (see _define_left_axes).
        gs = self.gs[1, 0] if not self.row_dendrogram else self.left_gs[0, 1]
        self.left_annotation._set_orentation("left")
        self.left_annotation.plot_annotations(
            ax=self.ax_left_annotation,
            subplot_spec=gs,
            idxs=row_order,
            hspace=self.hspace,
        )
    if self.right_annotation is not None:
        self.right_annotation._set_orentation("right")
        self.right_annotation.plot_annotations(
            ax=self.ax_right_annotation,
            subplot_spec=self.gs[1, 2],
            idxs=row_order,
            hspace=self.hspace,
        )
    if (self.row_cluster or self.col_cluster) and (self.row_dendrogram or self.col_dendrogram):
        self.plot_dendrograms(row_order, col_order)
    self.set_axes_labels_kws()
    self.set_xy_labels()
    self.collect_legends()
def tight_layout(self, **tight_params):
    """
    Run ``self.ax.figure.tight_layout`` while leaving room for the legends.

    Parameters
    ----------
    **tight_params:
        keyword arguments forwarded to ``figure.tight_layout``. When none are
        given, ``h_pad=0.02, w_pad=0.02`` is used. If ``rect`` is not given, it
        is ``[left, 0, right, 1]``. ``right`` (legends on the right) or ``left``
        (legends on the left) is taken from ``self.boundry``, which is set by
        ``plot_legends``. The other side is 1 / 0 respectively.
    """
    # **kwargs is always a dict (never None), so test for emptiness to decide
    # whether the defaults apply.
    if not tight_params:
        tight_params = {"h_pad": 0.02, "w_pad": 0.02}
    left, right = 0, 1
    # self.boundry only exists once plot_legends has actually drawn legends.
    boundry = getattr(self, "boundry", None)
    if self.legend and boundry is not None:
        if self.legend_side == "right":
            right = boundry
        elif self.legend_side == "left":
            left = boundry
    tight_params.setdefault("rect", [left, 0, right, 1])
    self.ax.figure.tight_layout(**tight_params)
def set_height(self, fig, height):
    """
    Set the height of *fig*.

    Parameters
    ----------
    fig: matplotlib Figure
        the figure to resize.
    height: float
        new height in inches, passed straight to ``fig.set_figheight``.
        No unit conversion is performed.
    """
    fig.set_figheight(height)
def set_width(self, fig, width):
    """
    Set the width of *fig*.

    Parameters
    ----------
    fig: matplotlib Figure
        the figure to resize.
    width: float
        new width in inches, passed straight to ``fig.set_figwidth``.
        No unit conversion is performed.
    """
    fig.set_figwidth(width)
def post_processing(self):
    """Placeholder for post-plot processing; currently performs no work."""
    return None
def composite(
    cmlist=None,
    main=0,
    ax=None,
    axis=1,
    row_gap=15,
    col_gap=15,
    legend_side="right",
    legend_gap=5,
    legend_y=0.8,
    legend_hpad=None,
    legend_width=None,
    width_ratios=None,
    height_ratios=None,
    verbose=1,
):
    """
    Assemble multiple ClusterMapPlotter objects vertically or horizontally together.

    Parameters
    ----------
    cmlist: list
        a non-empty list of ClusterMapPlotter (with plot=False).
    main: int
        index into cmlist of the main ClusterMapPlotter (negative values count
        from the end). It is plotted first, and its row order (axis=1) or
        column order (axis=0) is reused for all other plotters.
    ax: matplotlib Axes
        container axes to draw into; ``plt.gca()`` when None. It only hosts the
        grid of sub-plots and is switched off (``set_axis_off``).
    axis: int
        1 for columns (align the cmlist horizontally), 0 for rows (vertically).
    row/col_gap: float
        the row or columns gap between subplots, unit is mm [15].
    legend_side: str
        right,left [right].
    legend_gap: float
        row gap between two legends, unit is mm.
    legend_y: float
        passed to plot_legend_list as ``y0`` [0.8].
    legend_hpad: float
        horizontal space before the legends, in points (converted with
        ``dpi / 72``). When None, ``col_gap`` plus the widest row-label width
        of all plotters is used.
    legend_width: float
        default is None, will be estimated automatically
    width_ratios: list
        a list of width, values can be either float or int. Defaults to the
        number of columns of each ``data2d`` (horizontal layout only).
    height_ratios: list
        a list of height, values can be either float or int. Defaults to the
        number of rows of each ``data2d`` (vertical layout only).
    verbose: int
        passed to plot_legend_list.

    Returns
    -------
    tuple: ax,legend_axes
        or None when none of the plotters produced any legend.

    Raises
    ------
    ValueError
        if cmlist is None or empty.
    IndexError
        if main is out of range for cmlist.
    """
    if cmlist is None or len(cmlist) == 0:
        raise ValueError(
            "cmlist must be a non-empty list of ClusterMapPlotter objects."
        )
    n = len(cmlist)
    # Normalise negative indices so the `i == main` skip below matches.
    main = range(n)[main]
    if ax is None:
        ax = plt.gca()

    wspace, hspace = 0, 0
    if axis == 1:  # horizontally: one row, n columns
        # gridspec spacing is a fraction of the average cell size (mm -> px).
        wspace = col_gap * mm2inch * ax.figure.dpi / (ax.get_window_extent().width / n)
        nrows, ncols = 1, n
        if width_ratios is None:
            width_ratios = [cm.data2d.shape[1] for cm in cmlist]
        height_ratios = None
    else:  # vertically: n rows, one column
        hspace = row_gap * mm2inch * ax.figure.dpi / (ax.get_window_extent().height / n)
        nrows, ncols = n, 1
        width_ratios = None
        if height_ratios is None:
            height_ratios = [cm.data2d.shape[0] for cm in cmlist]
    gs = ax.figure.add_gridspec(
        nrows,
        ncols,
        width_ratios=width_ratios,
        height_ratios=height_ratios,
        wspace=wspace,
        hspace=hspace,
    )

    def _cell(i):
        # Grid cell of the i-th plotter: a row when stacking, a column otherwise.
        return gs[i, 0] if axis == 0 else gs[0, i]

    def _label_width(cm):
        # Rescale cm.label_max_width by container-figure width / cm's figure
        # width (1.0 when both are the same figure).
        # NOTE(review): confirm this denominator; assumes cm.plot() sets cm.ax.
        return (
            ax.figure.get_window_extent().width
            * cm.label_max_width
            / cm.ax.figure.get_window_extent().width
        )

    axes = []
    for i in range(n):
        # Side-by-side plots share y; stacked plots share x.
        sharex = axes[0] if axis == 0 and i > 0 else None
        sharey = axes[0] if axis == 1 and i > 0 else None
        sub_ax = ax.figure.add_subplot(_cell(i), sharex=sharex, sharey=sharey)
        sub_ax.set_axis_off()
        axes.append(sub_ax)

    cm_main = cmlist[main]
    cm_main.plot(ax=axes[main], subplot_spec=_cell(main), row_order=None, col_order=None)
    # Copy so merging other plotters' legends does not mutate cm_main.legend_list.
    legend_list = list(cm_main.legend_list)
    legend_names = {L[1] for L in legend_list}
    label_max_width = _label_width(cm_main)

    for i, cm in enumerate(cmlist):
        if i == main:
            continue
        if axis == 1:  # side by side: reuse the main plotter's row order
            row_order, col_order = cm_main.row_order, None
        else:  # stacked: reuse the main plotter's column order
            row_order, col_order = None, cm_main.col_order
        cm.plot(ax=axes[i], subplot_spec=_cell(i), row_order=row_order, col_order=col_order)
        # Keep only the first legend seen for each legend name.
        for L in cm.legend_list:
            if L[1] not in legend_names:
                legend_names.add(L[1])
                legend_list.append(L)
        label_max_width = max(label_max_width, _label_width(cm))

    if len(legend_list) == 0:
        # Still hide the container axes' frame/ticks behind the sub-plots.
        ax.set_axis_off()
        return None

    legend_list = sorted(legend_list, key=lambda x: x[3])
    if legend_hpad is None:
        space = col_gap * mm2inch * ax.figure.dpi + label_max_width
    else:
        space = legend_hpad * ax.figure.dpi / 72
    legend_axes, _cbars, _boundry = plot_legend_list(
        legend_list,
        ax=ax,
        space=space,
        legend_side=legend_side,
        gap=legend_gap,
        y0=legend_y,
        legend_width=legend_width,
        verbose=verbose,
    )
    ax.set_axis_off()
    return ax, legend_axes
