# -*- coding: utf-8 -*-
# !/usr/bin/env python3
import os, sys
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pylab as plt
from .utils import mm2inch
from .utils import (
_calculate_luminance,
cluster_labels,
plot_legend_list,
define_cmap,
get_colormap,
)
from .clustermap import plot_heatmap, heatmap
# -----------------------------------------------------------------------------
class AnnotationBase:
    """
    Base class for annotation objects.

    Parameters
    ----------
    df: dataframe
        a pandas series or dataframe (only one column). The data are copied on
        construction, so the caller's object is never modified.
    cmap: str
        colormap, such as Set1, Dark2, bwr, Reds, jet, hsv, rainbow and so on. Please see
        https://matplotlib.org/3.5.0/tutorials/colors/colormaps.html for more information, or run
        matplotlib.pyplot.colormaps() to see all available cmaps.
        default cmap is 'auto', it would be determined based on the dtype for each columns of df.
    colors: dict, list or str
        a dict or list (for boxplot, barplot) or str.
        If colors is a dict: keys should be exactly the same as df.iloc[:,0].unique(),
        values for the dict should be colors (color names or HEX color).
        If colors is a list, then the length of this list should be equal to df.iloc[:,0].nunique()
        (missing values are not counted and do not consume a color).
        If colors is a string, means all values in df.iloc[:,0].unique() share the same color.
    height: float
        height (if axis=1) / width (if axis=0) for the annotation size.
    legend: bool
        whether to plot legend for this annotation when legends are plotted or
        plot legend with HeatmapAnnotation.plot_legends().
    legend_kws: dict
        vmax, vmin and other kws passed to plt.legend, such as title, prop, fontsize, labelcolor,
        markerscale, frameon, framealpha, fancybox, shadow, facecolor, edgecolor, mode and so on, for more
        arguments, please type ?plt.legend. There is an additional parameter `color_text` (default is True),
        which would set the color of the text to the same color as legend marker. if one set
        `legend_kws={'color_text':False}`, then, black would be the default color for the text.
        If the user want to use a custom color instead of black (such as blue), please set
        legend_kws={'color_text':False,'labelcolor':'blue'}.
    plot_kws: dict
        other plot kws passed to annotation.plot, such as rotation, rotation_mode, ha, va,
        annotation_clip, arrowprops and matplotlib.text.Text for anno_label. For example, in anno_simple,
        there is also kws: vmin and vmax, if one want to change the range, please try:
        anno_simple(df_box.Gene1,vmin=0,vmax=1,legend_kws={'vmin':0,'vmax':1}).

    Returns
    ----------
    Class AnnotationBase.
    """

    def __init__(
        self,
        df=None,
        cmap="auto",
        colors=None,
        height=None,
        legend=None,
        legend_kws=None,
        **plot_kws
    ):
        self._check_df(df)
        self.label = None
        self.ylim = None
        self.color_dict = None
        self.nrows = self._n_rows()
        self.ncols = self._n_cols()
        self.height = self._height(height)
        self._type_specific_params()
        self.legend = legend
        # Copy so later setdefault() calls never leak into a dict the caller
        # may share between several annotations.
        self.legend_kws = dict(legend_kws) if legend_kws is not None else {}
        self._set_default_plot_kws(plot_kws)
        if colors is None:
            self._check_cmap(cmap)
            self._calculate_colors()  # may re-encode self.df; sets self.color_dict
        else:
            self._check_colors(colors)
            self._calculate_cmap()  # may re-encode self.df; sets self.color_dict, self.cmap
        self.plot_data = self.df.copy()

    def _check_df(self, df):
        """Store a private one-column-or-more DataFrame copy of ``df`` as ``self.df``.

        Raises
        ------
        TypeError
            if ``df`` is neither a Series nor a DataFrame.
        """
        if isinstance(df, pd.Series):
            df = df.to_frame()
        if not isinstance(df, pd.DataFrame):
            raise TypeError("df must be a pandas DataFrame or Series.")
        # _calculate_colors/_calculate_cmap re-encode the column in place;
        # working on a copy keeps the caller's DataFrame untouched.
        self.df = df.copy()

    def _n_rows(self):
        return self.df.shape[0]

    def _n_cols(self):
        return self.df.shape[1]

    def _height(self, height):
        """Default size is 3 per data column unless ``height`` is given."""
        return 3 * self.ncols if height is None else height

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = {} if plot_kws is None else plot_kws
        self.plot_kws.setdefault("zorder", 10)

    def set_orientation(self, orientation):
        self.orientation = orientation

    def update_plot_kws(self, plot_kws):
        self.plot_kws.update(plot_kws)

    def set_label(self, label):
        self.label = label

    def set_legend(self, legend):
        """Set ``self.legend`` only if it was not explicitly given before."""
        if self.legend is None:
            self.legend = legend

    def set_axes_kws(self, subplot_ax):
        """
        Configure tick and label placement on ``subplot_ax``.

        Reads ``self.axis``, ``self.ticklabels_side``, ``self.label_side``,
        ``self.label_kws`` and ``self.ticklabels_kws``. These are not set in
        this class; presumably the owning HeatmapAnnotation assigns them
        (TODO confirm).
        """
        if self.axis == 1:
            if self.ticklabels_side == "left":
                subplot_ax.yaxis.tick_left()
            elif self.ticklabels_side == "right":
                subplot_ax.yaxis.tick_right()
            subplot_ax.yaxis.set_label_position(self.label_side)
            subplot_ax.yaxis.label.update(self.label_kws)
            subplot_ax.xaxis.set_visible(False)
            subplot_ax.yaxis.label.set_visible(False)
        else:  # axis=0, row annotation
            if self.ticklabels_side == "top":
                subplot_ax.xaxis.tick_top()
            elif self.ticklabels_side == "bottom":
                subplot_ax.xaxis.tick_bottom()
            subplot_ax.xaxis.set_label_position(self.label_side)
            subplot_ax.xaxis.label.update(self.label_kws)
            # set_tick_params takes keyword arguments; passing the dict
            # positionally would bind it to ``which`` and raise ValueError.
            subplot_ax.xaxis.set_tick_params(**(self.ticklabels_kws or {}))
            subplot_ax.yaxis.set_visible(False)
            subplot_ax.xaxis.label.set_visible(False)

    @staticmethod
    def _is_discrete_dtype(dtype):
        """Return True when a column dtype should be colored category by category."""
        return (
            dtype == object
            or isinstance(dtype, pd.CategoricalDtype)
            or pd.api.types.is_string_dtype(dtype)
            or pd.api.types.is_bool_dtype(dtype)
        )

    def _check_cmap(self, cmap):
        """
        Resolve ``cmap`` into ``self.cmap``.

        With ``cmap='auto'``, discrete columns get Set1 / tab20 / random50
        depending on the number of unique values, and numeric columns get jet.
        For a 256-color (continuous) colormap, vmin/vmax defaults are taken
        from the data range when it can be computed.

        Raises
        ------
        TypeError
            if cmap is 'auto' and the column is neither discrete nor numeric.
        """
        if isinstance(cmap, str) and cmap == "auto":
            col = self.df.columns[0]
            series = self.df.iloc[:, 0]
            dtype = series.dtype
            if self._is_discrete_dtype(dtype):
                n_unique = series.nunique()
                if n_unique <= 10:
                    self.cmap = "Set1"
                elif n_unique <= 20:
                    self.cmap = "tab20"
                else:
                    self.cmap = "random50"
            elif pd.api.types.is_numeric_dtype(dtype):
                # covers every int/float width, not only int64/float64
                self.cmap = "jet"
            else:
                raise TypeError(
                    "Can not assign cmap for column %s, please specify cmap" % (col,)
                )
        elif isinstance(cmap, str):
            self.cmap = cmap
        else:
            print("WARNING: cmap is not a string!")
            self.cmap = cmap
        if get_colormap(self.cmap).N == 256:
            # continuous colormap: pin the color range to the data range
            try:
                self.plot_kws.setdefault("vmax", np.nanmax(self.df.values))
                self.plot_kws.setdefault("vmin", np.nanmin(self.df.values))
            except (TypeError, ValueError):
                # non-comparable or empty data: leave the range to the heatmap
                pass

    def _calculate_colors(self):
        """Build ``self.color_dict``; discrete values are re-encoded as integer codes in ``self.df``."""
        self.color_dict = {}
        col = self.df.columns.tolist()[0]
        colormap = get_colormap(self.cmap)
        if colormap.N < 256 or self._is_discrete_dtype(self.df.dtypes.iloc[0]):
            # codes are assigned by descending frequency (value_counts order)
            cc_list = self.df[col].value_counts().index.tolist()
            codes = {v: i for i, v in enumerate(cc_list)}
            self.df[col] = self.df[col].map(codes)
            self.color_dict = {v: colormap(i) for v, i in codes.items()}
        else:  # float
            self.color_dict = {v: colormap(v) for v in self.df[col].values}
        self.colors = None

    def _check_colors(self, colors):
        """
        Normalize user ``colors`` into a ``{value: color}`` dict in ``self.colors``.

        Raises
        ------
        TypeError
            if ``colors`` has an unsupported type or too few entries.
        """
        if not isinstance(colors, (str, list, dict, tuple)):
            raise TypeError("colors must be a str, list, tuple or dict.")
        series = self.df.iloc[:, 0]
        n_unique = series.nunique()
        if isinstance(colors, str):
            color_dict = {label: colors for label in series.unique()}
        elif isinstance(colors, (list, tuple)):
            if len(colors) != n_unique:
                raise TypeError(
                    "The length of `colors` is not consistent with the shape of the input data"
                )
            # nunique() ignores missing values, so they must not consume a color
            color_dict = dict(zip(series.dropna().unique(), colors))
        else:
            color_dict = colors.copy()
        if len(color_dict) < n_unique:
            raise TypeError(
                "The length of `colors` is not consistent with the shape of the input data"
            )
        self.colors = color_dict

    def _calculate_cmap(self):
        """Build a ListedColormap from ``self.colors`` and re-encode ``self.df`` as integer codes."""
        self.color_dict = self.colors
        col = self.df.columns.tolist()[0]
        cc_list = list(self.color_dict.keys())  # column values
        self.df[col] = self.df[col].map({v: i for i, v in enumerate(cc_list)})
        self.cmap = matplotlib.colors.ListedColormap(
            [self.color_dict[k] for k in cc_list]
        )
        self.plot_kws.setdefault("vmax", self.cmap.N - 1)
        self.plot_kws.setdefault("vmin", 0)

    def _type_specific_params(self):
        """Set ``self.ylim`` to the data range padded by 5% on each side."""
        if self.ylim is None:
            Max = np.nanmax(self.df.values)
            Min = np.nanmin(self.df.values)
            gap = Max - Min
            if gap > 0:
                pad = 0.05 * gap
            else:
                # constant data: avoid identical lower/upper limits
                pad = 0.05 * abs(Max) or 0.5
            self.ylim = [Min - pad, Max + pad]

    def reorder(self, idx):
        """Reindex the data to ``idx`` (the clustered order) before plotting; unknown labels become NaN."""
        self.plot_data = self.df.reindex(idx)
        self.plot_data.fillna(np.nan, inplace=True)
        self.nrows = self.plot_data.shape[0]

    def get_label_width(self):
        return self.ax.yaxis.label.get_window_extent(
            renderer=self.ax.figure.canvas.get_renderer()
        ).width

    def get_ticklabel_width(self):
        """Width in pixels of the widest y tick label, or 0 when there are none."""
        yticklabels = self.ax.yaxis.get_ticklabels()
        if len(yticklabels) == 0:
            return 0
        renderer = self.ax.figure.canvas.get_renderer()
        return max(
            label.get_window_extent(renderer=renderer).width for label in yticklabels
        )

    def get_max_label_width(self):
        return max([self.get_label_width(), self.get_ticklabel_width()])
# =============================================================================
class anno_simple(AnnotationBase):
    """
    Annotate simple annotation, categorical or continuous variables.

    Parameters (in addition to those of AnnotationBase)
    ----------
    add_text: bool
        whether to write the labels on top of the colored cells.
    majority: bool
        passed to cluster_labels when add_text is True: if there are multiple
        groups for one label, whether to annotate the label in the largest group.
    text_kws: dict
        kws passed to ax.text when add_text is True. If 'color' is given it is
        used for every label; otherwise each label is written in black or white
        depending on the luminance of its own cell color.
    """

    def __init__(
        self,
        df=None,
        cmap="auto",
        colors=None,
        add_text=False,
        majority=True,
        text_kws=None,
        height=None,
        legend=True,
        legend_kws=None,
        **plot_kws
    ):
        self.add_text = add_text
        self.majority = majority
        # text_kws must exist before super().__init__, which calls the
        # _set_default_plot_kws override below. Copy it so the defaults filled
        # in there do not leak into the caller's dict.
        self.text_kws = dict(text_kws) if text_kws is not None else {}
        self.plot_kws = plot_kws
        # Copy legend_kws for the same reason; vmin/vmax from plot_kws are
        # mirrored into it as defaults.
        legend_kws = dict(legend_kws) if legend_kws is not None else {}
        for key in ("vmax", "vmin"):
            if key in plot_kws:
                legend_kws.setdefault(key, plot_kws[key])
        super().__init__(
            df=df,
            cmap=cmap,
            colors=colors,
            height=height,
            legend=legend,
            legend_kws=legend_kws,
            **plot_kws
        )

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = {} if plot_kws is None else plot_kws
        self.plot_kws.setdefault("zorder", 10)
        self.text_kws.setdefault("zorder", 16)
        self.text_kws.setdefault("ha", "center")
        self.text_kws.setdefault("va", "center")

    def _calculate_colors(self):
        """Build self.color_dict; for discrete colormaps also record the value order in self.cc_list."""
        col = self.df.columns.tolist()[0]
        colormap = get_colormap(self.cmap)
        if colormap.N < 256:
            cc_list = self.df[col].value_counts().index.tolist()  # sorted by value counts
            self.color_dict = {v: colormap(i) for i, v in enumerate(cc_list)}
        else:  # float
            cc_list = None
            self.color_dict = {v: colormap(v) for v in self.df[col].values}
        self.cc_list = cc_list
        self.colors = None

    def _calculate_cmap(self):
        """Build a ListedColormap from the user colors, keeping their key order as self.cc_list."""
        self.color_dict = self.colors
        self.cc_list = list(self.color_dict.keys())  # column values
        self.cmap = matplotlib.colors.ListedColormap(
            [self.color_dict[k] for k in self.cc_list]
        )

    def _type_specific_params(self):
        pass

    def plot(self, ax=None, axis=1):
        """Draw the annotation as a one-row (axis=1) or one-column (axis=0) heatmap on ``ax``."""
        if ax is None:
            ax = plt.gca()
        if hasattr(self.cmap, "N"):
            vmax = self.cmap.N - 1
        elif isinstance(self.cmap, str):
            vmax = get_colormap(self.cmap).N - 1
        else:
            vmax = len(self.color_dict) - 1
        # only defaults: vmin/vmax already set from the data or by the user win
        self.plot_kws.setdefault("vmax", vmax)
        self.plot_kws.setdefault("vmin", 0)
        if self.cc_list:
            codes = {v: i for i, v in enumerate(self.cc_list)}
            mat = self.plot_data.iloc[:, 0].map(codes).values
        else:
            mat = self.plot_data.values
        matrix = mat.reshape(1, -1) if axis == 1 else mat.reshape(-1, 1)
        plot_heatmap(
            matrix,
            cmap=self.cmap,
            ax=ax,
            xticklabels=False,
            yticklabels=False,
            **self.plot_kws
        )  # y will be inverted inside plot_heatmap
        ax.tick_params(
            axis="both",
            which="both",
            left=False,
            right=False,
            top=False,
            bottom=False,
            labeltop=False,
            labelbottom=False,
            labelleft=False,
            labelright=False,
        )
        if self.add_text:
            self._add_text_labels(ax, axis)
        self.ax = ax
        self.fig = self.ax.figure
        return self.ax

    def _add_text_labels(self, ax, axis):
        """Write each label (merged via cluster_labels) at the centre of its cells."""
        if axis == 0:
            self.text_kws.setdefault("rotation", 90)
            self.text_kws.setdefault("rotation_mode", "anchor")
        labels, ticks = cluster_labels(
            self.plot_data.iloc[:, 0].values,
            np.arange(0.5, self.nrows, 1),
            self.majority,
        )
        n = len(ticks)
        if axis == 1:
            x, y = ticks, [0.5] * n
        else:
            x, y = [0.5] * n, ticks
        extent = ax.get_window_extent()
        s = extent.height if axis == 1 else extent.width
        self.text_kws.setdefault("fontsize", 72 * s * 0.8 / ax.figure.dpi)
        # A user-supplied color applies to every label; otherwise the color is
        # chosen per label, so it must not be stored back into self.text_kws.
        fixed_color = self.text_kws.get("color")
        text_kws = {k: v for k, v in self.text_kws.items() if k != "color"}
        for x0, y0, t in zip(x, y, labels):
            if pd.isna(t):  # missing values have no entry in color_dict
                continue
            if fixed_color is None:
                lum = _calculate_luminance(self.color_dict[t])
                text_color = "black" if lum > 0.408 else "white"
            else:
                text_color = fixed_color
            ax.text(x0, y0, t, color=text_color, **text_kws)
# =============================================================================
class anno_label(AnnotationBase):
    """
    Add label and text annotations. See example on documentation website:
    https://dingwb.github.io/PyComplexHeatmap/build/html/notebooks/single_cell_methylation.html

    Parameters
    ----------
    merge: bool
        whether to merge the same clusters into one and label only once.
    extend: bool
        whether to distribute all the labels evenly along the whole axis (in axes
        coordinates) instead of placing each label at a fixed offset from its point.
    frac: float
        fraction of the armA and armB.
    majority: bool
        If there are multiple group for one label, whether to annotate the label in the largest group. [True]
    adjust_color: bool
        When the luminance of the color is too high, use black color replace the original color. [True]
    luminance: float
        luminance values [0-1], used together with adjust_color, when the calculated luminance > luminance,
        the color will be replaced with black. [0.8]
    relpos: tuple
        relpos passed to arrowprops in plt.annotate, tuple (x,y) means the arrow start point position relative to the
        label. default is (0, 0) if self.orientation == 'up' else (0, 1) for columns labels, (1, 1) if self.orientation == 'left'
        else (0, 0) for rows labels.
    plot_kws: dict
        passed to plt.annotate, including annotation_clip, arrowprops and matplotlib.text.Text,
        more information about arrowprops could be found in
        matplotlib.patches.FancyArrowPatch. For example, to remove arrow, just set
        arrowprops = dict(visible=False,)

    Returns
    ----------
    Class AnnotationBase.
    """

    def __init__(
        self,
        df=None,
        cmap="auto",
        colors=None,
        merge=False,
        extend=False,
        frac=0.2,
        majority=True,
        adjust_color=True,
        luminance=0.8,
        height=None,
        legend=False,
        legend_kws=None,
        relpos=None,
        **plot_kws
    ):
        super().__init__(
            df=df,
            cmap=cmap,
            colors=colors,
            height=height,
            legend=legend,
            legend_kws=legend_kws,
            **plot_kws
        )
        self.merge = merge
        self.majority = majority
        self.adjust_color = adjust_color
        self.luminance = luminance
        self.extend = extend
        self.frac = frac
        self.relpos = relpos
        self.annotated_texts = []

    def _height(self, height):
        return 4 if height is None else height

    def set_plot_kws(self, axis):
        """Fill rotation/alignment/arrowprops defaults in self.plot_kws for the current orientation."""
        shrink = 1  # shrinkA/shrinkB of the arrow, in points
        if axis == 1:  # columns
            # relpos: anchor of the arrow on the text box, (0, 0) lower left, (1, 1) upper right
            relpos = (0, 0) if self.orientation == "up" else (0, 1)
            rotation = 90 if self.orientation == "up" else -90
            ha = "left"
        else:
            relpos = (1, 1) if self.orientation == "left" else (0, 0)
            rotation = 0
            ha = "right" if self.orientation == "left" else "left"
        self.plot_kws.setdefault("rotation", rotation)
        self.plot_kws.setdefault("ha", ha)
        self.plot_kws.setdefault("va", "center")
        default_arrowprops = dict(
            arrowstyle="-",
            color="black",
            shrinkA=shrink,
            shrinkB=shrink,
            relpos=relpos if self.relpos is None else self.relpos,
            patchA=None,
            patchB=None,
            connectionstyle=None,
            linewidth=0.5,
        )
        # user-supplied entries win; build a new dict instead of mutating the user's
        user_arrowprops = self.plot_kws.get("arrowprops") or {}
        self.plot_kws["arrowprops"] = {**default_arrowprops, **user_arrowprops}
        self.plot_kws.setdefault("rotation_mode", "anchor")

    def _calculate_colors(self):
        """Build self.color_dict from the colormap (by frequency for discrete data)."""
        col = self.df.columns.tolist()[0]
        colormap = get_colormap(self.cmap)
        if colormap.N < 256 or self.df.dtypes[col] == object:
            cc_list = self.df[col].value_counts().index.tolist()  # sorted by value counts
            self.color_dict = {v: colormap(i) for i, v in enumerate(cc_list)}
        else:  # float
            self.color_dict = {v: colormap(v) for v in self.df[col].values}
        self.colors = None

    def _calculate_cmap(self):
        self.color_dict = self.colors
        self.cmap = matplotlib.colors.ListedColormap(
            [self.color_dict[k] for k in self.color_dict]
        )

    def _type_specific_params(self):
        pass

    @staticmethod
    def _infer_orientation(ax, axis):
        """Guess the label side from the position of ``ax`` among the figure's axes."""
        i = ax.figure.axes.index(ax) / len(ax.figure.axes)
        if axis == 1:
            return "up" if i <= 0.5 else "down"
        if axis == 0 and i <= 0.5:
            return "left"
        return "right"

    def plot(self, ax=None, axis=1):
        """Draw one annotate() label (with connecting arrow) per (merged) row on ``ax``."""
        if ax is None:
            ax = plt.gca()
        self.axis = axis
        # orientation is normally set via set_orientation(); infer it otherwise
        if getattr(self, "orientation", None) is None:
            self.orientation = self._infer_orientation(ax, axis)
        self.set_plot_kws(axis)
        # only keep the texts of this drawing (used by get_ticklabel_width)
        self.annotated_texts = []
        all_ticks = np.arange(0.5, self.nrows, 1)
        if self.merge:
            # merge adjacent identical labels into one, placed at their mean coordinate
            labels, ticks = cluster_labels(
                self.plot_data.iloc[:, 0].values, all_ticks, self.majority
            )
        else:
            labels = self.plot_data.iloc[:, 0].values
            ticks = all_ticks
        n = len(ticks)
        # convert height (mm) to inch and to pixels
        arrow_height = self.height * mm2inch * ax.figure.dpi
        text_y = arrow_height
        if axis == 1:
            if self.orientation == "down":
                text_y = -1 * arrow_height
            ax.set_xticks(ticks=all_ticks)
            x = ticks  # arrow start, data coordinates
            y = [0] * n if self.orientation == "up" else [1] * n
            if self.extend:
                extend_pos = np.linspace(0, 1, n + 1)
                x1 = (extend_pos[1:] + extend_pos[:-1]) / 2  # centres of n equal slots
                y1 = [1] * n if self.orientation == "up" else [0] * n
            else:
                x1 = [0] * n  # offset pixels
                y1 = [text_y] * n
        else:
            if self.orientation == "left":
                text_y = -1 * arrow_height
            ax.set_yticks(ticks=all_ticks)
            y = ticks
            x = [1] * n if self.orientation == "left" else [0] * n
            if self.extend:
                extend_pos = np.linspace(1, 0, n + 1)  # top -> bottom
                y1 = (extend_pos[1:] + extend_pos[:-1]) / 2
                x1 = [1] * n if self.orientation == "right" else [0] * n
            else:  # offset pixels
                y1 = [0] * n
                x1 = [text_y] * n
        # x (axis=1) or y (axis=0) in data coordinates, the other in [0, 1]
        xycoords = ax.get_xaxis_transform() if axis == 1 else ax.get_yaxis_transform()
        text_xycoords = ax.transAxes if self.extend else "offset pixels"
        if self.plot_kws["arrowprops"]["connectionstyle"] is None:
            # angleA is the angle at the data point, angleB at the text
            arm_height = arrow_height * self.frac
            rad = 2
            if axis == 1 and self.orientation == "up":
                angleA, angleB = (self.plot_kws["rotation"] - 180, 90)
            elif axis == 1 and self.orientation == "down":
                angleA, angleB = (180 + self.plot_kws["rotation"], -90)
            elif axis == 0 and self.orientation == "left":
                angleA, angleB = (self.plot_kws["rotation"], -180)
            else:
                angleA, angleB = (self.plot_kws["rotation"] - 180, 0)
            self.plot_kws["arrowprops"]["connectionstyle"] = (
                f"arc,angleA={angleA},angleB={angleB},armA={arm_height},armB={arm_height},rad={rad}"
            )
        base_arrowprops = self.plot_kws["arrowprops"]
        annotate_kws = {k: v for k, v in self.plot_kws.items() if k != "arrowprops"}
        for t, x_0, y_0, x_1, y_1 in zip(labels, x, y, x1, y1):
            if pd.isna(t):
                continue
            color = self.color_dict[t]
            if self.adjust_color and _calculate_luminance(color) > self.luminance:
                color = "black"
            annotated_text = ax.annotate(
                text=t,
                xy=(x_0, y_0),  # the point to annotate
                xytext=(x_1, y_1),  # text position, in textcoords
                xycoords=xycoords,
                textcoords=text_xycoords,
                color=color,
                arrowprops={**base_arrowprops, "color": color},  # per-label copy
                **annotate_kws
            )
            self.annotated_texts.append(annotated_text)
        ax.set_axis_off()
        self.ax = ax
        self.fig = self.ax.figure
        return self.ax

    def get_ticklabel_width(self):
        """Width in pixels of the widest label drawn by the last plot(), or 0."""
        return max(
            (text.get_window_extent().width for text in self.annotated_texts),
            default=0,
        )
# =============================================================================
class anno_boxplot(AnnotationBase):
    """
    annotate boxplots, all arguments are included in AnnotationBase,
    plot_kws for anno_boxplot include showfliers, edgecolor, grid, medianlinecolor
    widths, zorder and other arguments passed to plt.boxplot.

    Parameters
    ----------
    """

    def _height(self, height):
        return 10 if height is None else height

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = plot_kws if plot_kws is not None else {}
        self.plot_kws.setdefault("showfliers", False)
        self.plot_kws.setdefault("edgecolor", "black")
        self.plot_kws.setdefault("medianlinecolor", "black")
        self.plot_kws.setdefault("grid", False)
        self.plot_kws.setdefault("zorder", 10)
        self.plot_kws.setdefault("widths", 0.5)

    def _check_cmap(self, cmap):
        if isinstance(cmap, str) and cmap == "auto":
            self.cmap = "jet"
        elif isinstance(cmap, str):
            self.cmap = cmap
        else:
            print("WARNING: cmap for boxplot is not a string!")
            self.cmap = cmap

    def _calculate_colors(self):
        self.colors = None

    def _check_colors(self, colors):
        if not isinstance(colors, str):
            raise TypeError(
                "Boxplot only support one string as colors now, if more colors are wanted, cmap can be specified."
            )
        self.colors = colors

    def _calculate_cmap(self):
        # a single fixed color carries no information worth a legend
        self.set_legend(False)
        self.cmap = None

    def plot(self, ax=None, axis=1):
        """Draw one box per row of plot_data (values across columns) on ``ax``."""
        if ax is None:
            ax = plt.gca()
        fig = ax.figure
        if self.colors is None:  # calculate colors based on cmap
            # NOTE(review): row means are passed to the colormap unscaled;
            # matplotlib colormaps map floats in [0, 1], so means outside that
            # range saturate at the end colors. Confirm this is intended.
            colormap = get_colormap(self.cmap)
            # positional row means also work with duplicated index labels
            colors = [colormap(m) for m in self.plot_data.mean(axis=1).values]
        else:
            colors = [self.colors] * self.plot_data.shape[0]  # self.colors is a string
        plot_kws = self.plot_kws.copy()
        # these keys are applied manually below, not accepted by ax.boxplot
        edgecolor = plot_kws.pop("edgecolor")
        mlinecolor = plot_kws.pop("medianlinecolor")
        grid = plot_kws.pop("grid")
        positions = np.arange(0.5, self.nrows, 1)
        vert = axis == 1  # vertical boxes for column annotations
        if vert:
            ax.set_xticks(ticks=positions)
        else:
            ax.set_yticks(ticks=positions)
        bp = ax.boxplot(
            x=self.plot_data.T.values,
            positions=positions,
            patch_artist=True,
            vert=vert,
            **plot_kws
        )
        if grid:
            ax.grid(linestyle="--", zorder=-10)
        for box, color in zip(bp["boxes"], colors):
            box.set_facecolor(color)
            box.set_edgecolor(edgecolor)
        for median_line in bp["medians"]:
            median_line.set_color(mlinecolor)
        if axis == 1:
            ax.set_xlim(0, self.nrows)
            ax.set_ylim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                top=False,
                bottom=False,
                labeltop=False,
                labelbottom=False,
            )
        else:
            ax.set_ylim(0, self.nrows)
            ax.set_xlim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                left=False,
                right=False,
                labelleft=False,
                labelright=False,
            )
        self.fig = fig
        self.ax = ax
        return self.ax
# =============================================================================
class anno_barplot(anno_boxplot):
    """
    Annotate barplot, all arguments are included in AnnotationBase,
    plot_kws for anno_barplot include edgecolor, grid, align, zorder,
    and other arguments passed to plt.bar / plt.barh.
    With more than one column, the columns are stacked.
    """

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = plot_kws if plot_kws is not None else {}
        self.plot_kws.setdefault("edgecolor", "black")
        self.plot_kws.setdefault("grid", False)
        self.plot_kws.setdefault("zorder", 10)
        self.plot_kws.setdefault("align", "center")

    def _check_cmap(self, cmap):
        """
        Resolve cmap; stacked bars (>= 2 columns) need a discrete colormap.

        Raises
        ------
        TypeError
            if a continuous (>= 256 colors) cmap is given for a stacked barplot.
        """
        if cmap == "auto":
            self.cmap = "jet" if self.ncols == 1 else "Set1"
        else:
            self.cmap = cmap
        if self.ncols >= 2 and get_colormap(self.cmap).N >= 256:
            raise TypeError(
                "cmap for stacked barplot should not be continuous, you should try: Set1, Dark2 and so on."
            )

    def _calculate_colors(self):
        """Color stacked bars by column, or a single column by its values through a normalized cmap."""
        col_list = self.df.columns.tolist()
        self.color_dict = {}
        if self.ncols >= 2:  # colored by column names
            colormap = get_colormap(self.cmap)
            self.colors = [colormap(i) for i in range(len(col_list))]  # list
            self.color_dict = dict(zip(col_list, self.colors))
        else:  # only one column, colored by its (float) values
            self.cmap, normalize = define_cmap(
                self.df[col_list[0]].fillna(np.nan).values,
                vmin=None,
                vmax=None,
                cmap=self.cmap,
                center=None,
                robust=False,
                na_col="white",
            )
            # self.colors is a function: value -> hex color
            self.colors = lambda v: matplotlib.colors.rgb2hex(self.cmap(normalize(v)))
            self.color_dict = None

    def _check_colors(self, colors):
        """
        Map user colors onto the columns of df and store them in self.color_dict.

        Raises
        ------
        TypeError
            if colors is not a str, list, tuple or dict.
        ValueError
            if a list/tuple does not have one color per column.
        """
        if not isinstance(colors, (list, str, dict, tuple)):
            raise TypeError("colors must be list of string,list, tuple or dict")
        self.colors = colors
        col_list = self.df.columns.tolist()
        if isinstance(colors, str):
            color_dict = {label: colors for label in col_list}
        elif isinstance(colors, (list, tuple)):
            if len(colors) != self.ncols:
                raise ValueError("length of colors should match length of df.columns")
            color_dict = dict(zip(col_list, colors))
        else:
            # keep only the entries that name a column of df
            known = set(col_list)
            color_dict = {k: v for k, v in colors.items() if k in known}
        self.color_dict = color_dict

    def _calculate_cmap(self):
        self.cmap = None

    def _type_specific_params(self):
        """Set self.stacked and a 5%-padded ylim (on row sums when stacked)."""
        self.stacked = self.ncols > 1
        if self.ylim is None:
            totals = self.df.sum(axis=1).values if self.stacked else self.df.values
            Max = np.nanmax(totals)
            Min = np.nanmin(totals)
            gap = Max - Min
            if gap > 0:
                pad = 0.05 * gap
            else:
                # constant data: avoid identical lower/upper limits
                pad = 0.05 * abs(Max) or 0.5
            self.ylim = [Min - pad, Max + pad]

    def plot(self, ax=None, axis=1):
        """Draw (stacked) bars, vertical for axis=1 and horizontal for axis=0, on ``ax``."""
        if ax is None:
            ax = plt.gca()
        fig = ax.figure
        plot_kws = self.plot_kws.copy()
        grid = plot_kws.pop("grid", False)
        if grid:
            ax.grid(linestyle="--", zorder=-10)
        if self.ncols == 1 and self.cmap is not None:  # one column, colored via cmap
            colors = [[self.colors(v) for v in self.plot_data.iloc[:, 0].values]]
        else:  # colored via color_dict
            assert self.color_dict is not None
            colors = [self.color_dict[col] for col in self.plot_data.columns]
        positions = np.arange(0.5, self.nrows, 1)
        if axis == 1:
            ax.set_xticks(ticks=positions)
        else:
            ax.set_yticks(ticks=positions)
        base_coordinates = np.zeros(self.plot_data.shape[0])
        for col, color in zip(self.plot_data.columns, colors):
            values = self.plot_data[col].values
            if axis == 1:  # columns annotations
                ax.bar(
                    x=positions,
                    height=values,
                    bottom=base_coordinates,
                    color=color,
                    **plot_kws
                )
            else:
                ax.barh(
                    y=positions,
                    width=values,
                    left=base_coordinates,
                    color=color,
                    **plot_kws
                )
            # a missing segment adds 0, so the segments above it are still drawn
            base_coordinates = base_coordinates + np.asarray(
                self.plot_data[col].fillna(0), dtype=float
            )
        if axis == 1:
            ax.set_xlim(0, self.nrows)
            ax.set_ylim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                top=False,
                bottom=False,
                labeltop=False,
                labelbottom=False,
            )
        else:
            ax.set_ylim(0, self.nrows)
            ax.set_xlim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                left=False,
                right=False,
                labelleft=False,
                labelright=False,
            )
        self.fig = fig
        self.ax = ax
        return self.ax
# =============================================================================
class anno_scatterplot(anno_barplot):
    """
    Annotate scatterplot, all arguments are included in AnnotationBase,
    plot_kws for anno_scatterplot include linewidths, grid, edgecolors
    and other arguments passed to plt.scatter. ``c``, ``s`` and ``cmap`` in
    plot_kws override the values computed from the data.
    """

    def _check_df(self, df):
        """
        Store a single-column DataFrame as self.df.

        Raises
        ------
        TypeError
            if df is neither a Series nor a DataFrame.
        ValueError
            if df has more than one column.
        """
        if isinstance(df, pd.Series):
            # to_frame() keeps the series name, or uses 0 when it has none
            df = df.to_frame()
        if not isinstance(df, pd.DataFrame):
            raise TypeError("df must be a pandas DataFrame or Series.")
        if df.shape[1] != 1:
            raise ValueError("df must have only 1 column for scatterplot")
        self.df = df

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = plot_kws if plot_kws is not None else {}
        self.plot_kws.setdefault("grid", False)
        self.plot_kws.setdefault("zorder", 10)
        self.plot_kws.setdefault("linewidths", 0)
        self.plot_kws.setdefault("edgecolors", "black")

    def _check_cmap(self, cmap):
        if isinstance(cmap, str) and cmap == "auto":
            self.cmap = "jet"
        elif isinstance(cmap, str):
            self.cmap = cmap
        else:
            print("WARNING: cmap for scatterplot is not a string!")
            self.cmap = cmap

    def _calculate_colors(self):
        self.colors = None

    def _check_colors(self, colors):
        if not isinstance(colors, str):
            raise TypeError(
                "colors must be string for scatterplot, if more colors are neded, please try cmap!"
            )
        self.colors = colors

    def _calculate_cmap(self):
        self.cmap = None
        self.set_legend(False)

    def _type_specific_params(self):
        """Set self.gap (data span, 1 for constant data) and a 5%-padded ylim."""
        Max = np.nanmax(self.df.values)
        Min = np.nanmin(self.df.values)
        # a zero span would divide by zero when marker sizes are scaled in plot()
        self.gap = Max - Min if Max > Min else 1
        if self.ylim is None:
            self.ylim = [Min - 0.05 * self.gap, Max + 0.05 * self.gap]

    def plot(self, ax=None, axis=1):
        """Scatter the values (marker size grows with the value) on ``ax``."""
        if ax is None:
            ax = plt.gca()
        fig = ax.figure
        plot_kws = self.plot_kws.copy()
        grid = plot_kws.pop("grid", False)
        if grid:
            ax.grid(linestyle="--", zorder=-10)
        values = self.plot_data.iloc[:, 0].values
        if self.colors is None:
            colors = values
        else:  # self.colors is a string
            colors = [self.colors] * self.plot_data.shape[0]
        extent = ax.get_window_extent()
        span = extent.height if axis == 1 else extent.width
        spu = span * 72 / self.gap / fig.dpi  # size per unit
        # nanmin: missing rows (e.g. after reorder) must not turn every size into NaN
        self.s = (values - np.nanmin(values) + self.gap * 0.1) * spu
        positions = np.arange(0.5, self.nrows, 1)
        if axis == 1:
            ax.set_xticks(ticks=positions)
            x, y = positions, values
        else:
            ax.set_yticks(ticks=positions)
            x, y = values, positions
        # pop (not get) so user overrides are not passed to scatter twice
        c = plot_kws.pop("c", colors)
        s = plot_kws.pop("s", self.s)
        cmap = plot_kws.pop("cmap", self.cmap)
        ax.scatter(x=x, y=y, c=c, s=s, cmap=cmap, **plot_kws)
        if axis == 0:  # row annotations
            ax.set_ylim(0, self.nrows)
            ax.set_xlim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                left=False,
                right=False,
                labelleft=False,
                labelright=False,
            )
        else:  # columns annotations
            ax.set_xlim(0, self.nrows)
            ax.set_ylim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                top=False,
                bottom=False,
                labeltop=False,
                labelbottom=False,
            )
        self.fig = fig
        self.ax = ax
        return self.ax
class anno_img(AnnotationBase):
    """
    Annotate images.

    Each entry of df is an image path (local file or URL). Images are read
    with PIL, converted to ``mode``, and re-read at the smallest common size
    when their sizes differ or some entries are missing. Missing entries (NaN)
    are drawn as blocks filled with ``border_color``. The images are then
    stacked and drawn with ``ax.imshow``.

    Parameters
    ----------
    border_width : int
        width (in pixels) of the border lines between images. Ignored when merge is True.
    border_color : int
        value used for border lines and for missing images, e.g. black: 0, white: 255.
        Borders are not drawn when merge is True.
    merge : bool
        whether to merge the same clusters into one and show the image only once.
    merge_width : float
        width of the image (in data units) when merge is True.
    rotate : int
        rotate the input images by this angle in degrees (passed to PIL ``Image.rotate``).
    mode : str
        PIL mode the images are converted to, e.g. "L", "RGB", "CMYK" or "RGBA";
        default is "RGBA".
    """

    def __init__(
        self,
        df=None,
        cmap=None,
        colors=None,
        border_width=1,
        border_color=255,
        merge=False,
        merge_width=1,
        rotate=None,
        mode="RGBA",
        **plot_kws
    ):
        self.border_width = border_width
        self.border_color = border_color
        self.merge = merge
        self.merge_width = merge_width
        self.rotate = rotate
        self.mode = mode
        self.plot_kws = plot_kws
        super().__init__(df=df, cmap=cmap, colors=colors, **plot_kws)

    def _height(self, height):
        # Default annotation size when height is not given.
        return 10 if height is None else height

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = plot_kws if plot_kws is not None else {}

    def _calculate_colors(self):  # images carry their own colours
        self.colors = None

    def _check_cmap(self, cmap):
        self.cmap = None

    def read_img(self, img_path=None, shape=None):
        """
        Read one image and return it as a numpy array.

        Parameters
        ----------
        img_path : str or NaN
            local file path, or a URL when no such local file exists.
            Missing values produce a placeholder filled with ``self.border_color``.
        shape : tuple, optional
            target size in PIL order ``(width, height[, channels])``; when None
            the image keeps its own size.

        Returns
        -------
        numpy.ndarray or None
            array of shape (height, width[, channels]); None when ``img_path``
            is missing and ``shape`` is None.
        """
        from io import BytesIO

        import requests
        from PIL import Image  # "import PIL" alone does not load PIL.Image

        if pd.isna(img_path):
            if shape is None:
                return None
            # shape is (width, height, ...) but numpy arrays are (height, width, ...)
            new_shape = tuple([shape[1], shape[0]] + list(shape[2:]))
            return np.full(new_shape, self.border_color)
        if os.path.exists(img_path):
            image = Image.open(img_path)
        else:  # remote file
            response = requests.get(img_path)
            # Report HTTP failures directly instead of PIL's "cannot identify image file".
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        if image.mode != self.mode:
            image = image.convert(self.mode)
        if shape is not None:
            image = image.resize(shape[:2])  # (width, height)
        if self.rotate is not None:
            image = image.rotate(self.rotate)
        return np.array(image)

    def _add_border(self, img, width=1, color=0, axis=1):
        """
        Pad ``img`` with a constant border of ``width`` pixels.

        axis=1 pads left and right (images are stacked horizontally); axis=0
        pads top and bottom (images are stacked vertically). Any trailing
        (channel) axes are left untouched, so 2-D arrays (e.g. mode "L") work too.
        """
        pad_width = [(0, 0)] * img.ndim
        pad_width[1 if axis == 1 else 0] = (width, width)
        return np.pad(img, pad_width=pad_width, mode="constant", constant_values=color)

    def _type_specific_params(self):
        pass

    def plot(self, ax=None, axis=1):
        """
        Draw the images on ``ax``.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Axes to draw on; ``plt.gca()`` is used when None.
        axis : int
            1 for a column annotation (images side by side), 0 for a row
            annotation (images stacked vertically).

        Returns
        -------
        matplotlib Axes

        Raises
        ------
        ValueError
            when every image path is missing.
        """
        if ax is None:
            ax = plt.gca()
        imgfiles = list(self.plot_data.iloc[:, 0])
        if axis == 0:
            imgfiles = imgfiles[::-1]
        imgs = [self.read_img(img_path=imgfile) for imgfile in imgfiles]
        shapes = [img.shape for img in imgs if img is not None]  # (height, width[, channel])
        if len(shapes) == 0:
            raise ValueError("anno_img: all image paths are missing, nothing to plot.")
        if len(set(shapes)) > 1 or len(shapes) != len(imgs):
            # Sizes differ or some paths are missing: re-read every image at the
            # smallest common size so they can be stacked. Missing paths become
            # placeholders filled with border_color.
            if len(shapes) > 1:
                shape = np.min(np.vstack(shapes), axis=0)
            else:
                shape = shapes[0]
            new_shape = tuple([shape[1], shape[0]] + list(shape[2:]))  # PIL order: width, height
            imgs = [self.read_img(img_path=imgfile, shape=new_shape) for imgfile in imgfiles]
            shapes = [img.shape for img in imgs]
        img_h, img_w = shapes[0][0], shapes[0][1]
        # Local copy: the defaults set below must not leak into self.plot_kws.
        plot_kws = self.plot_kws.copy()
        if self.merge:
            origin = "upper"
            assert (
                self.plot_data.iloc[:, 0].dropna().nunique() == 1
            ), "Not all file names in the list are identical"
            # Use the first real image, not a placeholder made for a missing path.
            first = next(k for k, f in enumerate(imgfiles) if not pd.isna(f))
            imgs = imgs[first]
            # extent = (left, right, bottom, top) in data coordinates
            if axis == 1:  # columns annotation
                extent = [
                    self.nrows / 2 - self.merge_width / 2,
                    self.nrows / 2 + self.merge_width / 2,
                    0,
                    img_h,
                ]
            else:
                extent = [
                    0,
                    img_w,
                    self.nrows / 2 - self.merge_width / 2,
                    self.nrows / 2 + self.merge_width / 2,
                ]
        elif axis == 1:
            imgs = np.hstack(
                [
                    self._add_border(img, width=self.border_width, color=self.border_color, axis=axis)
                    for img in imgs
                ]
            )
            extent = [0, self.nrows, 0, img_h]
            origin = "upper"
        else:  # axis == 0
            # y is shared with other axes, so inverting it here has no effect;
            # origin='lower' plus stacking the (already reversed) list in reverse
            # order gives top -> bottom order.
            origin = "lower"
            if self.orientation == "left":
                ax.set_xlim(img_w, 0)  # ax.invert_xaxis() has no effect here
            imgs = np.vstack(
                [
                    self._add_border(img, width=self.border_width, color=self.border_color, axis=axis)
                    for img in imgs[::-1]
                ]
            )
            extent = [0, img_w, 0, self.nrows]
        plot_kws.setdefault("origin", origin)
        ax.imshow(imgs, aspect="auto", extent=extent, cmap=self.cmap, **plot_kws)
        ax.tick_params(
            axis="both",
            which="both",
            labelbottom=False,
            labelleft=False,
            labelright=False,
            labeltop=False,
            bottom=False,
            left=False,
            right=False,
            top=False,
        )
        self.ax = ax
        self.fig = self.ax.figure
        return self.ax
class anno_lineplot(anno_barplot):
    """
    Annotate lineplot: one line is drawn for every column of df.

    All arguments are included in AnnotationBase. The extra plot keyword
    ``grid`` controls whether to show a grid (default is False); other
    arguments are passed to plt.plot, including linewidth, marker and so on.
    """

    def _check_df(self, df):
        if isinstance(df, pd.Series):
            # to_frame() keeps the series name as column name (0 when unnamed);
            # to_frame(name=None) would name the column None on pandas >= 2.
            self.df = df.to_frame()
        elif isinstance(df, pd.DataFrame):
            self.df = df
        else:
            raise TypeError("df must be a pandas DataFrame or Series.")

    def _set_default_plot_kws(self, plot_kws):
        self.plot_kws = plot_kws if plot_kws is not None else {}
        self.plot_kws.setdefault("grid", False)
        self.plot_kws.setdefault("zorder", 10)
        self.plot_kws.setdefault("linewidth", 1)

    def _check_cmap(self, cmap):
        # "auto" -> Set1; any string is a colormap name; other objects are kept as-is.
        if cmap == "auto":
            self.cmap = "Set1"
        elif isinstance(cmap, str):
            self.cmap = cmap
        else:
            print("WARNING: cmap for lineplot is not a string!")
            self.cmap = cmap

    def _calculate_colors(self):  # add self.color_dict (each col is a dict)
        """Give the k-th column the k-th entry of self.cmap."""
        colormap = get_colormap(self.cmap)
        col_list = self.df.columns.tolist()
        self.colors = [colormap(k) for k in range(len(col_list))]
        self.color_dict = dict(zip(col_list, self.colors))

    def plot(self, ax=None, axis=1):  # add self.gs,self.fig,self.ax,self.axes
        """
        Draw one line per column on ``ax``.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Axes to draw on; ``plt.gca()`` is used when None.
        axis : int
            1 for a column annotation (positions along x), 0 for a row
            annotation (positions along y).

        Returns
        -------
        matplotlib Axes
        """
        if ax is None:
            ax = plt.gca()
        fig = ax.figure
        plot_kws = self.plot_kws.copy()
        if plot_kws.pop("grid", False):
            ax.grid(linestyle="--", zorder=-10)
        positions = np.arange(0.5, self.nrows, 1)
        for col, color in self.color_dict.items():
            values = self.plot_data[col].values
            if axis == 1:
                ax.plot(positions, values, color=color, **plot_kws)
            else:
                ax.plot(values, positions, color=color, **plot_kws)
        # Ticks and limits do not depend on the line, so set them once.
        if axis == 1:
            ax.set_xticks(ticks=positions)
            ax.set_xlim(0, self.nrows)
            ax.set_ylim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                top=False,
                bottom=False,
                labeltop=False,
                labelbottom=False,
            )
        else:
            ax.set_yticks(ticks=positions)
            ax.set_ylim(0, self.nrows)
            ax.set_xlim(*self.ylim)
            ax.tick_params(
                axis="both",
                which="both",
                left=False,
                right=False,
                labelleft=False,
                labelright=False,
            )
        self.fig = fig
        self.ax = ax
        return self.ax
# =============================================================================
class HeatmapAnnotation:
    """
    Generate and plot heatmap annotations.

    Parameters
    ----------
    df : dataframe
        a pandas dataframe, each column will be converted to one anno_simple class.
        A list, numpy array or Series is converted to a one-column dataframe.
    axis : int
        1 for columns annotation, 0 for rows annotations.
    cmap : str, list or dict
        colormap, such as Set1, Dark2, bwr, Reds, jet, hsv, rainbow and so on. Please see
        https://matplotlib.org/3.5.0/tutorials/colors/colormaps.html for more information, or run
        matplotlib.pyplot.colormaps() to see all available cmap.
        default cmap is 'auto': it is determined from the dtype of each column in df
        (text columns: Set1, tab20 or random50 depending on the number of unique values;
        integer and float columns: jet).
        If df is None, then there is no need to specify cmap, cmap and colors will only be used when
        df is provided.
        If cmap is a string, then all columns in the df would have the same cmap. cmap can also be
        a list (one cmap per column, or a single cmap in a list shared by all columns), or
        a dict, keys are the column names from df, values should be cmap (matplotlib.pyplot.colormaps()).
    colors : dict
        a dict, keys are the column names of df, values are list, dict or string passed to
        AnnotationBase.__subclasses__(), including anno_simple, anno_boxplot, anno_label and anno_scatter.
        colors must have the same length as df.columns. If colors is not provided (default),
        colors would be calculated based on the given cmap.
        If colors is given, then the cmap is ignored.
    label_side : str
        left or right when axis=1 (columns annotation), top or bottom when axis=0 (rows annotations).
        Default is right for axis=1 and top for axis=0.
    label_kws : dict
        kws passed to the labels of the annotation labels (would be df.columns if df is given).
        such as alpha, color, fontsize, fontstyle, ha (horizontalalignment),
        va (verticalalignment), rotation, rotation_mode, visible, rasterized and so on.
        For more information, see plt.gca().yaxis.label.properties() or ax.yaxis.label.properties()
    ticklabels_kws : dict
        label_kws is for the label of annotation, ticklabels_kws is for the label (text) in anno_label,
        such as axis, which, direction, length, width,
        color, pad, labelsize, labelcolor, colors, zorder, bottom, top, left, right, labelbottom, labeltop,
        labelleft, labelright, labelrotation, grid_color, grid_linestyle and so on.
        For more information, see ?matplotlib.axes.Axes.tick_params
    plot_kws : dict
        kws passed to annotation functions, such as anno_simple, anno_label et.al.
    plot : bool
        whether to plot, when the annotation are included in clustermap, plot would be
        set to False automatically.
    legend : bool or dict
        True or False, or a dict whose keys are the columns of df (or the names of the
        name-value pairs); names missing from the dict get no legend.
    legend_side : str
        right or left
    legend_gap : float
        the vertical gap between two legends, default is 5 [mm]
    legend_width: float
        width of the legend, default is 4.5 [mm]
    legend_hpad: float
        Horizontal space between heatmap and legend, default is 2 [mm].
    legend_vpad: float
        Vertical space between top of ax and legend, default is 5 [mm].
    orientation: str
        up or down, when axis=1
        left or right, when axis=0;
        When anno_label shows up in annotation, the orientation would be automatically be assigned according
        to the position of anno_label.
    wgap: float or int
        optional, the width gap in mm used to calculate wspace, default is 0.1 (mm);
        controls the horizontal spacing between neighbouring subplots.
    hgap: float or int
        optional, the height gap in mm used to calculate hspace, default is 0.1 (mm);
        controls the vertical spacing between neighbouring subplots.
    plot_legend : bool
        whether to plot legends.
    rasterized : bool
        passed as ``rasterized`` to the anno_simple annotations built from df or from
        data (list/array/Series/DataFrame) name-value pairs.
    verbose : int
        print progress messages when >= 1.
    args : name-value pair
        key is the annotation label (name), values can be a pandas dataframe,
        series, list, numpy array, or annotation such as
        anno_simple, anno_boxplot, anno_scatter, anno_label, or anno_barplot.

    Returns
    -------
    Class HeatmapAnnotation.
    """

    def __init__(
        self,
        df=None,
        axis=1,
        cmap="auto",
        colors=None,
        label_side=None,
        label_kws=None,
        ticklabels_kws=None,
        plot_kws=None,
        plot=False,
        legend=True,
        legend_side="right",
        legend_gap=5,
        legend_width=4.5,
        legend_hpad=2,
        legend_vpad=5,
        orientation=None,
        wgap=0.1,
        hgap=0.1,
        plot_legend=True,
        rasterized=False,
        verbose=1,
        **args
    ):
        if df is None and len(args) == 0:
            raise ValueError("Please specify either df or other args")
        if df is not None and len(args) > 0:
            raise ValueError("df and Name-value pairs can only be given one, not both.")
        if df is not None:
            self._check_df(df)
        else:
            self.df = None
        self.axis = axis
        self.verbose = verbose
        self.label_side = label_side
        # Copy: _process_data writes "rasterized" into this dict.
        self.plot_kws = {} if plot_kws is None else dict(plot_kws)
        self.args = args
        self._check_legend(legend)
        self.legend_side = legend_side
        self.legend_gap = legend_gap
        self.wgap = wgap
        self.hgap = hgap
        self.legend_width = legend_width
        self.legend_hpad = legend_hpad
        self.legend_vpad = legend_vpad
        self.plot_legend = plot_legend
        self.rasterized = rasterized
        self.orientation = orientation
        self.plot = plot
        if colors is None:
            self._check_cmap(cmap)
            self.colors = None
        else:
            self._check_colors(colors)
        self._process_data()
        self._heights()
        self._nrows()
        self.label_kws, self.ticklabels_kws = label_kws, ticklabels_kws
        if self.plot:
            self.plot_annotations()

    def _check_df(self, df):
        """Store df as self.df, converting a list/array/Series to a one-column DataFrame."""
        if isinstance(df, (list, np.ndarray)):
            df = pd.Series(df).to_frame(name="df")
        elif isinstance(df, pd.Series):
            name = df.name if df.name is not None else "df"
            df = df.to_frame(name=name)
        if not isinstance(df, pd.DataFrame):
            raise TypeError("data type of df could not be recognized, should be a dataframe")
        self.df = df

    def _check_legend(self, legend):
        """Normalise ``legend`` to a dict {annotation name: bool} in self.legend."""
        if isinstance(legend, bool):
            if self.df is not None:
                self.legend = {col: legend for col in self.df.columns}
            if len(self.args) > 0:
                self.legend = {arg: legend for arg in self.args}
        elif isinstance(legend, dict):
            # Copy so that filling in defaults does not modify the caller's dict.
            self.legend = dict(legend)
            for arg in self.args:
                self.legend.setdefault(arg, False)
        else:
            raise TypeError("Unknow data type for legend!")

    def _check_cmap(self, cmap):
        """Build self.cmap, a dict {column: cmap}, from the ``cmap`` argument (df input only)."""
        if self.df is None:
            return
        ncols = self.df.shape[1]
        self.cmap = {}
        if isinstance(cmap, str) and cmap == "auto":
            for col in self.df.columns:
                dtype = self.df.dtypes[col]
                if dtype == object or pd.api.types.is_string_dtype(dtype):
                    nunique = self.df[col].nunique()
                    if nunique <= 10:
                        self.cmap[col] = "Set1"
                    elif nunique <= 20:
                        self.cmap[col] = "tab20"
                    else:
                        self.cmap[col] = "random50"
                elif pd.api.types.is_float_dtype(dtype) or pd.api.types.is_integer_dtype(dtype):
                    # any width (int32, float32, ...), not only float64/int64
                    self.cmap[col] = "jet"
                else:
                    raise TypeError(
                        "Can not assign cmap for column %s, please specify cmap" % col
                    )
        elif isinstance(cmap, str):
            self.cmap = {col: cmap for col in self.df.columns}
        elif isinstance(cmap, list):
            if len(cmap) == 1:
                cmap = cmap * ncols  # one cmap shared by every column
            if len(cmap) != ncols:
                raise ValueError(
                    "cmap must have the same length as the number of columns of df"
                )
            self.cmap = dict(zip(self.df.columns, cmap))
        elif isinstance(cmap, dict):
            if len(cmap) != ncols:
                raise ValueError(
                    "cmap must have the same length as the number of columns of df"
                )
            self.cmap = cmap
        else:
            print("WARNING: unknown datatype for cmap!")
            self.cmap = cmap

    def _check_colors(self, colors):
        """Validate the ``colors`` dict (one entry per df column) and store it."""
        if self.df is None:
            return
        if not isinstance(colors, dict):
            raise TypeError("colors must be a dict!")
        if len(colors) != self.df.shape[1]:
            raise ValueError("colors must have the same length as the df.columns!")
        self.colors = colors

    def _process_data(self):  # add self.annotations,self.names,self.labels
        """Create the annotation objects (self.annotations) from df or from the name-value pairs."""
        self.annotations = []
        self.plot_kws["rasterized"] = self.rasterized
        if self.df is not None:
            for col in self.df.columns:
                plot_kws = self.plot_kws.copy()
                if self.colors is None:
                    plot_kws.setdefault("cmap", self.cmap[col])
                else:
                    plot_kws.setdefault("colors", self.colors[col])
                anno1 = anno_simple(
                    self.df[col], legend=self.legend.get(col, False), **plot_kws
                )
                anno1.set_label(col)
                anno1.set_orientation(self.orientation)
                self.annotations.append(anno1)
            return
        self.labels = []
        for arg, ann in self.args.items():  # ann: list/array, Series, DataFrame or anno_*
            if isinstance(ann, (list, np.ndarray)):
                ann = pd.Series(ann).to_frame(name=arg)
            elif isinstance(ann, pd.Series):
                ann = ann.to_frame(name=arg)
            if isinstance(ann, pd.DataFrame):
                if ann.shape[1] > 1:
                    for col in ann.columns:
                        anno1 = anno_simple(
                            ann[col],
                            legend=self.legend.get(col, False),
                            **self.plot_kws
                        )
                        anno1.set_label(col)
                        self.annotations.append(anno1)
                else:
                    anno1 = anno_simple(ann, **self.plot_kws)
                    anno1.set_label(arg)
                    anno1.set_legend(self.legend.get(arg, False))
                    self.annotations.append(anno1)
            elif isinstance(ann, AnnotationBase):
                self.annotations.append(ann)
                ann.set_label(arg)
                ann.set_legend(self.legend.get(arg, False))
                # The first anno_label points outwards (up/left), later ones down/right.
                if type(ann) == anno_label and self.orientation is None:
                    if self.axis == 1 and len(self.labels) == 0:
                        self.orientation = "up"
                    elif self.axis == 1:
                        self.orientation = "down"
                    elif self.axis == 0 and len(self.labels) == 0:
                        self.orientation = "left"
                    elif self.axis == 0:
                        self.orientation = "right"
                ann.set_orientation(self.orientation)
            else:
                # Previously ignored silently, leaving labels and annotations out of sync.
                raise TypeError(
                    "Unsupported annotation type for %s: %s" % (arg, type(ann).__name__)
                )
            self.labels.append(arg)

    def _set_orentation(self, orientation):
        if self.orientation is None:
            self.orientation = orientation

    def _heights(self):
        self.heights = [ann.height for ann in self.annotations]

    def _nrows(self):
        self.nrows = [ann.nrows for ann in self.annotations]

    def _set_label_kws(self, label_kws, ticklabels_kws):
        """
        Validate label_side, fill in default orientation/label_side, and build
        self.label_kws, self.ticklabels_kws and self.ticklabels_side.
        The caller's dicts are copied, not modified.
        """
        if self.label_side in ["left", "right"] and self.axis != 1:
            raise ValueError(
                "For row annotation (axis=0), label_side must be top or bottom!"
            )
        if self.label_side in ["top", "bottom"] and self.axis != 0:
            raise ValueError(
                "For columns annotation (axis=1), label_side must be left or right!"
            )
        if self.orientation is None:
            self.orientation = "up" if self.axis == 1 else "left"
        self.label_kws = {} if label_kws is None else dict(label_kws)
        self.ticklabels_kws = {} if ticklabels_kws is None else dict(ticklabels_kws)
        self.label_kws.setdefault("rotation_mode", "anchor")
        if self.label_side is None:
            # columns annotation: ylabel on the right by default; rows: on top
            self.label_side = "right" if self.axis == 1 else "top"
        ha, va = "left", "center"
        if self.orientation == "left":
            rotation, labelrotation = 90, 90
            ha = "right" if self.label_side == "bottom" else "left"
        elif self.orientation == "right":
            ha = "right" if self.label_side == "top" else "left"
            rotation, labelrotation = -90, -90
        else:  # self.orientation == 'up' (or 'down')
            rotation, labelrotation = 0, 0
            ha = "left" if self.label_side == "right" else "right"
        self.label_kws.setdefault("rotation", rotation)
        self.ticklabels_kws.setdefault("labelrotation", labelrotation)
        self.label_kws.setdefault("horizontalalignment", ha)
        self.label_kws.setdefault("verticalalignment", va)
        # Tick labels go on the side opposite to the annotation labels.
        map_dict = {"right": "left", "left": "right", "top": "bottom", "bottom": "top"}
        self.ticklabels_side = map_dict[self.label_side]

    def set_axes_kws(self):
        """
        Place each annotation's label on ``label_side`` and, for annotations other
        than anno_simple/anno_img, show tick labels on the opposite side.
        """
        plain_types = [anno_simple, anno_img]  # annotations drawn without value ticks
        if self.axis == 1 and self.label_side == "left":
            self.ax.yaxis.tick_right()
            for i in range(self.axes.shape[0]):
                self.axes[i, 0].yaxis.set_visible(True)
                self.axes[i, 0].yaxis.label.set_visible(True)
                self.axes[i, 0].tick_params(
                    axis="y",
                    which="both",
                    left=False,
                    labelleft=False,
                    right=False,
                    labelright=False,
                )
                self.axes[i, 0].set_ylabel(self.annotations[i].label)
                self.axes[i, 0].yaxis.set_label_position(self.label_side)
                self.axes[i, 0].yaxis.label.update(self.label_kws)
                if type(self.annotations[i]) not in plain_types:
                    self.axes[i, -1].yaxis.set_visible(True)
                    self.axes[i, -1].tick_params(
                        axis="y", which="both", right=True, labelright=True
                    )
                    self.axes[i, -1].yaxis.set_tick_params(**self.ticklabels_kws)
        elif self.axis == 1 and self.label_side == "right":
            self.ax.yaxis.tick_left()
            for i in range(self.axes.shape[0]):
                self.axes[i, -1].yaxis.set_visible(True)
                self.axes[i, -1].yaxis.label.set_visible(True)
                self.axes[i, -1].tick_params(
                    axis="y",
                    which="both",
                    left=False,
                    labelleft=False,
                    right=False,
                    labelright=False,
                )
                self.axes[i, -1].set_ylabel(self.annotations[i].label)
                self.axes[i, -1].yaxis.set_label_position(self.label_side)
                self.axes[i, -1].yaxis.label.update(self.label_kws)
                if type(self.annotations[i]) not in plain_types:
                    self.axes[i, 0].yaxis.set_visible(True)
                    self.axes[i, 0].tick_params(
                        axis="y", which="both", left=True, labelleft=True
                    )
                    self.axes[i, 0].yaxis.set_tick_params(**self.ticklabels_kws)
        elif self.axis == 0 and self.label_side == "top":
            self.ax.xaxis.tick_bottom()
            for j in range(self.axes.shape[1]):
                self.axes[0, j].xaxis.set_visible(True)  # 0: the top axes
                self.axes[0, j].xaxis.label.set_visible(True)
                self.axes[0, j].tick_params(
                    axis="x",
                    which="both",
                    top=False,
                    labeltop=False,
                    bottom=False,
                    labelbottom=False,
                )
                self.axes[0, j].set_xlabel(self.annotations[j].label)
                self.axes[0, j].xaxis.set_label_position(self.label_side)
                self.axes[0, j].xaxis.label.update(self.label_kws)
                if type(self.annotations[j]) not in plain_types:
                    self.axes[-1, j].xaxis.set_visible(True)  # show ticks
                    self.axes[-1, j].tick_params(
                        axis="x", which="both", bottom=True, labelbottom=True
                    )
                    self.axes[-1, j].xaxis.set_tick_params(**self.ticklabels_kws)
        elif self.axis == 0 and self.label_side == "bottom":
            self.ax.xaxis.tick_top()
            for j in range(self.axes.shape[1]):
                self.axes[-1, j].xaxis.set_visible(True)
                self.axes[-1, j].xaxis.label.set_visible(True)
                self.axes[-1, j].tick_params(
                    axis="x",
                    which="both",
                    top=False,
                    labeltop=False,
                    bottom=False,
                    labelbottom=False,
                )
                self.axes[-1, j].set_xlabel(self.annotations[j].label)
                self.axes[-1, j].xaxis.set_label_position(self.label_side)
                self.axes[-1, j].xaxis.label.update(self.label_kws)
                if type(self.annotations[j]) not in plain_types:
                    self.axes[0, j].xaxis.set_visible(True)
                    self.axes[0, j].tick_params(
                        axis="x", which="both", top=True, labeltop=True
                    )
                    self.axes[0, j].xaxis.set_tick_params(**self.ticklabels_kws)

    def collect_legends(self):
        """
        Collect legends of the annotations whose ``legend`` is true.

        Fills ``self.legend_list`` with entries
        ``[color_dict_or_cmap, label, legend_kws, size, kind]``, where kind is
        "color_dict" for discrete legends (size = number of entries) and "cmap"
        for colour bars (size = 4), sorted by size. Also sets
        ``self.label_max_width``.

        Returns
        -------
        None
        """
        if self.verbose >= 1:
            print("Collecting annotation legends..")
        self.legend_list = []  # handles(dict) / cmap, title, kws
        for annotation in self.annotations:
            if not annotation.legend:
                continue
            legend_kws = annotation.legend_kws.copy()
            # Discrete colormaps (fewer than 256 colours) get a colour_dict legend,
            # continuous ones a colour bar.
            if (
                (annotation.cmap is None)
                or (hasattr(annotation.cmap, "N") and annotation.cmap.N < 256)
                or (
                    isinstance(annotation.cmap, str)
                    and get_colormap(annotation.cmap).N < 256
                )
            ):
                color_dict = annotation.color_dict
                if color_dict is None:
                    continue
                self.legend_list.append(
                    [
                        annotation.color_dict,
                        annotation.label,
                        legend_kws,
                        len(annotation.color_dict),
                        "color_dict",
                    ]
                )
            else:
                if annotation.df.shape[1] == 1:
                    array = annotation.df.iloc[:, 0].values
                else:
                    array = annotation.df.values
                vmax = np.nanmax(array)
                vmin = np.nanmin(array)
                legend_kws.setdefault("vmin", round(vmin, 2))
                legend_kws.setdefault("vmax", round(vmax, 2))
                self.legend_list.append(
                    [annotation.cmap, annotation.label, legend_kws, 4, "cmap"]
                )
        if len(self.legend_list) > 1:
            self.legend_list = sorted(self.legend_list, key=lambda x: x[3])
        if self.label_side == "right":
            self.label_max_width = max(
                [ann.get_max_label_width() for ann in self.annotations]
            )
        else:
            self.label_max_width = max(
                [ann.get_ticklabel_width() for ann in self.annotations]
            )

    def plot_annotations(
        self, ax=None, subplot_spec=None, idxs=None, wspace=None, hspace=None
    ):
        """
        Plot annotations

        Parameters
        ----------
        ax : ax
            axes to plot the annotations.
        subplot_spec : ax.figure.add_gridspec
            object from ax.figure.add_gridspec or matplotlib.gridspec.GridSpecFromSubplotSpec.
        idxs : list
            index to reorder df and df of annotation class.
        wspace : float
            if wspace is not None, use wspace, else wspace would be calculated based on wgap.
        hspace : float
            if hspace is not None, use hspace, else hspace would be calculated based on hgap.

        Returns
        -------
        self.ax
        """
        self._set_label_kws(self.label_kws, self.ticklabels_kws)
        if self.verbose >= 1:
            print("Starting plotting HeatmapAnnotations")
        self.ax = plt.gca() if ax is None else ax
        if idxs is None:
            idxs = [self.annotations[0].plot_data.index.tolist()]
        if self.axis == 1:
            nrows, ncols = len(self.heights), len(idxs)
            height_ratios = self.heights
            width_ratios = [len(idx) for idx in idxs]
        else:
            nrows, ncols = len(idxs), len(self.heights)
            width_ratios = self.heights
            height_ratios = [len(idx) for idx in idxs]
        # Convert the gaps from mm to a fraction of the average subplot size
        # (1 mm = mm2inch inch; inch * dpi = pixels).
        bbox = self.ax.get_window_extent()
        if wspace is None:
            wspace = self.wgap * mm2inch * self.ax.figure.dpi / (bbox.width / ncols)
        if hspace is None:
            hspace = self.hgap * mm2inch * self.ax.figure.dpi / (bbox.height / nrows)
        if subplot_spec is None:
            self.gs = self.ax.figure.add_gridspec(
                nrows,
                ncols,
                hspace=hspace,
                wspace=wspace,
                height_ratios=height_ratios,
                width_ratios=width_ratios,
            )
        else:  # this ax is a subplot of another bigger figure.
            self.gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
                nrows,
                ncols,
                hspace=hspace,
                wspace=wspace,
                subplot_spec=subplot_spec,
                height_ratios=height_ratios,
                width_ratios=width_ratios,
            )
        self.axes = np.empty(shape=(nrows, ncols), dtype=object)
        self.fig = self.ax.figure
        self.ax.set_axis_off()
        hidden_spines = ("top", "bottom") if self.axis == 1 else ("left", "right")
        for side in hidden_spines:
            self.ax.spines[side].set_visible(False)
        for j, idx in enumerate(idxs):  # columns if axis=1, rows if axis=0
            for i, ann in enumerate(self.annotations):  # rows if axis=1, columns if axis=0
                # axis=1: left -> right, axis=0: bottom -> top.
                ann.reorder(idx)
                gs = self.gs[i, j] if self.axis == 1 else self.gs[j, i]
                # Share x with the first annotation of this column block (axis=1),
                # or y with the first annotation of this row block (axis=0).
                sharex = self.axes[0, j] if self.axis == 1 else None
                sharey = None if self.axis == 1 else self.axes[j, 0]
                ax1 = self.ax.figure.add_subplot(gs, sharex=sharex, sharey=sharey)
                if self.axis == 1:
                    ax1.set_xlim([0, len(idx)])
                else:
                    ax1.set_ylim([0, len(idx)])
                ann.plot(ax=ax1, axis=self.axis)
                if self.axis == 1:
                    ax1.yaxis.label.set_visible(False)
                    ax1.tick_params(
                        left=False, right=False, labelleft=False, labelright=False
                    )
                    self.axes[i, j] = ax1
                    if self.orientation == "down":
                        ax1.invert_yaxis()
                else:  # horizontal
                    if type(ann) != anno_simple:
                        # 20230312 fix bug for inversed row order in anno_label.
                        ax1.invert_yaxis()
                    ax1.xaxis.label.set_visible(False)
                    ax1.tick_params(
                        top=False, bottom=False, labeltop=False, labelbottom=False
                    )
                    self.axes[j, i] = ax1
                    if self.orientation == "left":
                        ax1.invert_xaxis()
        self.set_axes_kws()
        self.legend_list = None
        if self.plot and self.plot_legend:
            self.plot_legends(ax=self.ax)
        return self.ax

    def show_ticklabels(self, labels, **kwargs):
        """
        Show ``labels`` as tick labels on the outer annotation axes.

        kwargs are passed to set_xticklabels (axis=1) or set_yticklabels (axis=0);
        rotation, ha, va and rotation_mode get defaults depending on orientation.
        """
        ha, va = "left", "center"
        if self.axis == 1:
            ax = self.axes[-1, 0] if self.orientation == "up" else self.axes[0, 0]
            rotation = -45 if self.orientation == "up" else 45
            ax.xaxis.set_visible(True)
            ax.xaxis.label.set_visible(True)
            if self.orientation == "up":
                ax.xaxis.set_ticks_position("bottom")
                ax.tick_params(axis="both", which="both", bottom=True, labelbottom=True)
            else:
                ax.xaxis.set_ticks_position("top")
                ax.tick_params(axis="both", which="both", top=True, labeltop=True)
        else:
            ax = self.axes[0, -1] if self.orientation == "left" else self.axes[0, 0]
            rotation = 0
            ax.yaxis.set_visible(True)
            ax.yaxis.label.set_visible(True)
            if self.orientation == "left":
                ax.yaxis.set_ticks_position("right")
                ax.tick_params(axis="both", which="both", right=True, labelright=True)
            else:
                ha = "right"
                ax.yaxis.set_ticks_position("left")
                ax.tick_params(axis="both", which="both", left=True, labelleft=True)
        kwargs.setdefault("rotation", rotation)
        kwargs.setdefault("ha", ha)
        kwargs.setdefault("va", va)
        kwargs.setdefault("rotation_mode", "anchor")
        if self.axis == 1:
            ax.set_xticklabels(labels, **kwargs)
        else:
            ax.set_yticklabels(labels, **kwargs)

    def plot_legends(self, ax=None):
        """
        Plot legends.

        Legends are collected first (collect_legends) when that has not been done.
        Requires plot_annotations to have been called (uses self.ax).

        Parameters
        ----------
        ax : matplotlib Axes
            reference axes passed to plot_legend_list.

        Returns
        -------
        None
        """
        if getattr(self, "legend_list", None) is None:
            self.collect_legends()
        if len(self.legend_list) > 0:
            # leave room for the annotation labels when they are on the legend side
            space = (
                self.label_max_width
                if (self.legend_side == "right" and self.label_side == "right")
                else 0
            )
            legend_hpad = self.legend_hpad * mm2inch * self.ax.figure.dpi  # mm -> inch -> pixel
            # NOTE(review): legend_side is hard-coded to "right" here although
            # self.legend_side is documented as right or left; confirm whether
            # plot_legend_list supports "left" before passing it through.
            self.legend_axes, self.cbars, self.boundry = plot_legend_list(
                self.legend_list,
                ax=ax,
                space=space + legend_hpad,
                legend_side="right",
                gap=self.legend_gap,
                legend_width=self.legend_width,
                legend_vpad=self.legend_vpad,
                verbose=self.verbose,
            )
# =============================================================================