# -*- coding: utf-8 -*-
# !/usr/bin/env python3
import os, sys
import pandas as pd
import numpy as np
from scipy.cluster import hierarchy
import matplotlib
import matplotlib.pylab as plt
from .utils import (
_draw_figure,
mm2inch,
plot_legend_list,
despine,
_index_to_ticklabels,
get_colormap,
)
from .clustermap import ClusterMapPlotter
from .annotations import HeatmapAnnotation, anno_barplot
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
# =============================================================================
def oncoprint(
    data,
    ax=None,
    colors=None,
    cmap="Set1",
    nvar=None,
    aspect=None,
    bgcolor="whitesmoke",
    row_gap=1,
    xticklabels_kws=None,
    subplot_spec=None,
    yticklabels_kws=None,
    **plot_kws
):
    """
    Plot an oncoprint: one strip of stacked bars per row of `data`.

    Every cell of `data` holds a list of numbers, one per variant. Inside a
    cell the numbers are rescaled to fractions of their sum and drawn as a
    stacked bar of total height 1. Cells whose values sum to zero are drawn
    as a full-height bar filled with `bgcolor`.

    Parameters
    ----------
    data : pd.DataFrame
        input matrix (pandas.DataFrame); each element should be a list of
        numbers, and all lists are expected to have the same length.
    ax : ax
        parent axes, default is `plt.gca()`. It carries the limits and tick
        labels; the bars are drawn on one child axes per row.
    colors : list
        one color per variant (position in the cell list). If None, the
        colors are taken from `cmap`.
    cmap : str
        colormap passed to `get_colormap`; only used when `colors` is None.
    nvar : int
        number of variants per cell. If None, it is taken from the longest
        list in the first column of `data`.
    aspect : optional
        if not None, passed to `set_aspect` of every row axes.
    bgcolor : str
        color of the bars drawn for empty cells, default is whitesmoke.
    row_gap : float
        gap between two rows, converted to pixels with `mm2inch` and the
        figure dpi.
    xticklabels_kws, yticklabels_kws : dict
        passed to `ax.xaxis.set_tick_params` / `ax.yaxis.set_tick_params`.
        The dicts given by the caller are copied, not modified.
    subplot_spec : optional
        if given, the row axes are laid out inside this subplot spec,
        otherwise in a new gridspec added to the figure.
    plot_kws : dict
        other kwargs passed to `ax.bar`; `width` defaults to 0.7 and `align`
        to "center".

    Returns
    -------
    ax, axes :
        the parent axes and an object array of shape (nrows, 1) holding the
        axes of every row.
    """
    if ax is None:
        ax = plt.gca()
    nrows, ncols = data.shape
    if nvar is None:
        nvar = data.iloc[:, 0].apply(len).max()
    if colors is None:
        # Resolve the colormap once instead of once per variant.
        palette = get_colormap(cmap)
        colors = [palette(i) for i in range(nvar)]
    plot_kws.setdefault("width", 0.7)
    plot_kws.setdefault("align", "center")
    # Work on copies so the defaults below never leak into the caller's dicts.
    xticklabels_kws = {} if xticklabels_kws is None else dict(xticklabels_kws)
    yticklabels_kws = {} if yticklabels_kws is None else dict(yticklabels_kws)

    rowticklabels = _index_to_ticklabels(data.index)
    colticklabels = _index_to_ticklabels(data.columns)

    # Parent axes: one unit per cell, ticks at the cell centres.
    ax.set_ylim(0, nrows)
    y_locater = list(np.arange(0.5, nrows, 1))
    ax.yaxis.set_major_locator(plt.FixedLocator(y_locater))
    ax.set_yticklabels(rowticklabels)
    ax.invert_yaxis()
    ax.set_xlim(0, ncols)
    x_locater = list(np.arange(0.5, ncols, 1))
    ax.xaxis.set_major_locator(plt.FixedLocator(x_locater))
    ax.set_xticklabels(colticklabels)
    # Switch every tick and label off, then re-enable only what is asked for.
    ax.tick_params(
        axis="both",
        which="both",
        left=False,
        right=False,
        labelleft=False,
        labelright=False,
        top=False,
        bottom=False,
        labeltop=False,
        labelbottom=False,
    )
    xticklabels_kws.setdefault("labelrotation", -90)
    xticklabels_kws.setdefault("labelbottom", True)
    yticklabels_kws.setdefault("labelleft", True)
    ax.xaxis.set_tick_params(**xticklabels_kws)
    ax.yaxis.set_tick_params(**yticklabels_kws)
    despine(ax=ax, left=True, bottom=True, right=True, top=True)

    # Express the requested gap as a fraction of the average row height,
    # which is the unit gridspec uses for hspace.
    hspace = row_gap * mm2inch * ax.figure.dpi / (ax.get_window_extent().height / nrows)
    if subplot_spec is None:
        gs = ax.figure.add_gridspec(nrows, 1, hspace=hspace)
    else:
        gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
            nrows, 1, hspace=hspace, subplot_spec=subplot_spec
        )
    axes = np.empty(shape=(nrows, 1), dtype=object)
    centers = np.arange(0.5, ncols, 1)  # bar centres, identical for all rows

    for i in range(nrows):
        # One line per sample, one column per variant, scaled so that every
        # non-empty sample sums to 1; empty samples (0 / 0 -> NaN) become 0.
        fractions = (
            pd.DataFrame(data.iloc[i, :].values.tolist())
            .apply(lambda x: x / x.sum(), axis=1)
            .fillna(0)
        )
        ax1 = ax.figure.add_subplot(gs[i, 0], sharex=ax)
        ax1.set_ylim([0, 1])
        ax1.yaxis.set_major_locator(plt.FixedLocator([0.5]))
        ax1.set_yticklabels([rowticklabels[i]])
        if aspect is not None:
            ax1.set_aspect(aspect)

        # Stack the variants on top of each other.
        bottoms = np.zeros(ncols)
        for j, color in zip(range(fractions.shape[1]), colors):
            heights = fractions.iloc[:, j].values
            ax1.bar(
                x=centers,
                height=heights,
                bottom=bottoms,
                color=color,
                **plot_kws
            )
            bottoms = heights + bottoms

        # Background: full-height bars for the samples without any value.
        empty = fractions.loc[fractions.sum(axis=1) == 0]
        ax1.bar(
            x=empty.index.values + 0.5,
            height=[1] * empty.shape[0],
            color=bgcolor,
            **plot_kws
        )
        ax1.set_axis_off()
        axes[i, 0] = ax1

    _draw_figure(ax.figure)
    return ax, axes
# =============================================================================
class oncoPrintPlotter(ClusterMapPlotter):
    """
    Oncoprint plotter, inherited from ClusterMapPlotter.
    Plot an oncoprint (stacked bars per cell) with annotation and legends.

    Parameters
    ----------
    data : dataframe
        pandas dataframe or numpy array.
    x: str
        The column name in data.columns to be shown on the columns of the plot.
    y : str
        The column name in data.columns to be shown on the rows of the plot.
    values : str or list
        The column name(s) in data.columns holding the values of each variant;
        a single name is wrapped into a list.
    colors :list.
        colors for each column in values
    cmap :str or dict, optional.
        colormap used when `colors` is None, such as 'Set1'.
    aspect : optional
        passed to `oncoprint`.
    bgcolor: str
        background color, default is whitesmoke
    row_gap : float
        gap between rows, passed to `oncoprint`.
    color_legend_kws: dict
        legend_kws passed to plot_color_dict_legend
    cmap_legend_kws: dict
        legend_kws passed to plot_cmap_legend
    remove_empty_rows, remove_empty_columns : bool
        stored as `remove_empty_rows` / `remove_empty_cols`.
        NOTE(review): not read anywhere in this class; presumably used when
        the data are formatted - confirm.
    kwargs :dict
        Other kwargs passed to ClusterMapPlotter and oncoprint,
        such as width, align (will be passed to oncoprint).

    Returns
    -------
    oncoPrintPlotter.
    """

    def __init__(
        self,
        data=None,
        x=None,
        y=None,
        values=None,
        cmap="Set1",
        colors=None,
        aspect=None,
        bgcolor="whitesmoke",
        row_gap=0.8,
        color_legend_kws=None,
        cmap_legend_kws=None,
        remove_empty_rows=True,
        remove_empty_columns=True,
        **kwargs
    ):
        kwargs["data"] = data
        self.x = x
        self.y = y
        self.values = values if isinstance(values, list) else [values]
        self.cmap = cmap
        self.colors = colors
        self.aspect = aspect
        self.bgcolor = bgcolor
        self.row_gap = row_gap
        # None instead of `{}` as default: a dict literal in the signature
        # would be shared by every instance.
        self.color_legend_kws = {} if color_legend_kws is None else color_legend_kws
        self.cmap_legend_kws = {} if cmap_legend_kws is None else cmap_legend_kws
        self.remove_empty_rows = remove_empty_rows
        self.remove_empty_cols = remove_empty_columns
        # Attributes must be set before the parent constructor runs.
        super().__init__(**kwargs)

    def _reorder_rows(self):
        """
        Fill `self.row_order`, a list with one list of row labels per split.

        Instead of following a dendrogram, rows are sorted by the row sums of
        `self.row_vc` (descending) when `row_cluster` is set.
        """
        if self.verbose >= 1:
            print("Reordering rows..")
        if self.row_split is None and self.row_cluster:
            self.row_order = [
                self.row_vc.sum(axis=1).sort_values(ascending=False).index.tolist()
            ]
            return None
        elif isinstance(self.row_split, int) and self.row_cluster:
            # Cluster on the total of every cell, then cut the tree into
            # `row_split` groups.
            self.calculate_row_dendrograms(self.data2d.applymap(lambda x: np.sum(x)))
            self.row_clusters = (
                pd.Series(
                    hierarchy.fcluster(
                        self.dendrogram_row.linkage,
                        t=self.row_split,
                        criterion="maxclust",
                    ),
                    index=self.data2d.index.tolist(),
                )
                .to_frame(name="cluster")
                .groupby("cluster")
                .apply(lambda x: x.index.tolist())
                .to_dict()
            )
        elif isinstance(self.row_split, (pd.Series, pd.DataFrame)):
            if isinstance(self.row_split, pd.Series):
                self.row_split = self.row_split.to_frame(name=self.row_split.name).loc[
                    self.data2d.index.tolist()
                ]
            cols = self.row_split.columns.tolist()
            row_clusters = self.row_split.groupby(cols).apply(
                lambda x: x.index.tolist()
            )
            if self.row_split_order:
                row_clusters = row_clusters.loc[self.row_split_order]
            self.row_clusters = row_clusters.to_dict()
        elif not self.row_cluster:
            self.row_order = [self.data2d.index.tolist()]
            return None
        else:
            raise TypeError("row_split must be integer or dataframe or series")

        self.row_order = []
        self.dendrogram_rows = []
        for rows in self.row_clusters.values():
            if len(rows) <= 1:
                self.row_order.append(rows)
                # NOTE(review): only groups with at most one row get an entry
                # here, so this list can be shorter than row_order.
                self.dendrogram_rows.append(None)
                continue
            if self.row_cluster:
                row_order = (
                    self.row_vc.loc[rows]
                    .sum(axis=1)
                    .sort_values(ascending=False)
                    .index.tolist()
                )
                self.row_order.append(row_order)
            else:
                self.row_order.append(rows)

    def get_samples_order(self, data, row_order):
        """
        Return the column labels of `data` sorted for the oncoprint.

        data is a dataframe, row_order is a list of lists ([[],[]]); the inner
        lists may have different lengths.

        Every non-empty cell (np.sum(cell) != 0) is weighted with
        2 ** (nrows - position of its row in row_order - 1), the weights are
        summed per column and the columns are sorted by that sum, descending.

        Raises
        ------
        ValueError
            if a row with a non-empty cell is missing from `row_order`.
        """
        nrows = data.shape[0]
        # Flatten with plain Python: np.array(row_order).flatten() does not
        # work for groups of unequal length.
        position = {}
        for pos, row in enumerate(r for group in row_order for r in group):
            position.setdefault(row, pos)  # keep the first occurrence

        def cell_weight(cell, row_name):
            if np.sum(cell) == 0:
                return 0
            try:
                pos = position[row_name]
            except KeyError:
                raise ValueError("%r is not in row_order" % (row_name,)) from None
            return 2 ** (nrows - pos - 1)

        # https://gist.github.com/armish/564a65ab874a770e2c26
        col_order = (
            data.apply(
                lambda x: x.apply(lambda y: cell_weight(y, x.name)),
                axis=1,
            )
            .sum()
            .sort_values(ascending=False)
            .index.tolist()
        )
        return col_order

    def _reorder_cols(self):
        """
        Fill `self.col_order`, a list with one list of column labels per split.

        When `col_cluster` is set the columns of every split are sorted with
        `get_samples_order` instead of following a dendrogram.
        """
        if self.verbose >= 1:
            print("Reordering cols..")
        if self.col_split is None and self.col_cluster:
            col_order = self.get_samples_order(self.data2d, self.row_order)
            self.col_order = [col_order]
            return None
        if isinstance(self.col_split, int) and self.col_cluster:
            self.calculate_col_dendrograms(self.data2d.applymap(lambda x: np.sum(x)))
            self.col_clusters = (
                pd.Series(
                    hierarchy.fcluster(
                        self.dendrogram_col.linkage,
                        t=self.col_split,
                        criterion="maxclust",
                    ),
                    index=self.data2d.columns.tolist(),
                )
                .to_frame(name="cluster")
                .groupby("cluster")
                .apply(lambda x: x.index.tolist())
                .to_dict()
            )
        elif isinstance(self.col_split, (pd.Series, pd.DataFrame)):
            if isinstance(self.col_split, pd.Series):
                self.col_split = self.col_split.to_frame(name=self.col_split.name).loc[
                    self.data2d.columns.tolist()
                ]
            cols = self.col_split.columns.tolist()
            col_clusters = self.col_split.groupby(cols).apply(
                lambda x: x.index.tolist()
            )
            if self.col_split_order:
                col_clusters = col_clusters.loc[self.col_split_order]
            self.col_clusters = col_clusters.to_dict()
        elif not self.col_cluster:
            self.col_order = [self.data2d.columns.tolist()]
            return None
        else:
            raise TypeError("col_split must be integer or dataframe or series")

        self.col_order = []
        self.dendrogram_cols = []
        for cols in self.col_clusters.values():
            if len(cols) <= 1:
                self.col_order.append(cols)
                # NOTE(review): see _reorder_rows, only these groups get an
                # entry in dendrogram_cols.
                self.dendrogram_cols.append(None)
                continue
            if self.col_cluster:
                col_order = self.get_samples_order(
                    self.data2d.loc[:, cols], self.row_order
                )
                self.col_order.append(col_order)
            else:
                self.col_order.append(cols)

    def plot_matrix(self, row_order, col_order):
        """
        Draw one oncoprint per (row split, column split) pair.

        row_order and col_order are lists of lists of labels. The parent axes
        of every panel are stored in `self.heatmap_axes` and the per-row axes
        returned by `oncoprint` in `self.heatmap_subaxes`.
        """
        if self.verbose >= 1:
            print("Plotting matrix..")
        nrows = len(row_order)
        ncols = len(col_order)
        # Gaps are given in mm; convert to pixels, then to the fraction of the
        # average panel size expected by gridspec.
        self.col_split_gap_pixel = self.col_split_gap * mm2inch * self.ax.figure.dpi
        self.wspace = (self.col_split_gap_pixel * ncols) / (
            self.ax_heatmap.get_window_extent().width
            + self.col_split_gap_pixel
            - self.col_split_gap_pixel * ncols
        )
        self.row_split_gap_pixel = self.row_split_gap * mm2inch * self.ax.figure.dpi
        self.hspace = (self.row_split_gap_pixel * nrows) / (
            self.ax_heatmap.get_window_extent().height
            + self.row_split_gap_pixel
            - self.row_split_gap_pixel * nrows
        )
        self.heatmap_gs = matplotlib.gridspec.GridSpecFromSubplotSpec(
            nrows,
            ncols,
            hspace=self.hspace,
            wspace=self.wspace,
            subplot_spec=self.gs[1, 1],
            height_ratios=[len(rows) for rows in row_order],
            width_ratios=[len(cols) for cols in col_order],
        )
        self.heatmap_axes = np.empty(shape=(nrows, ncols), dtype=object)
        self.heatmap_subaxes = np.empty(shape=(nrows, ncols), dtype=object)
        self.ax_heatmap.set_axis_off()
        for i, rows in enumerate(row_order):
            for j, cols in enumerate(col_order):
                # The array starts filled with None, so the first panel of a
                # row / column shares with nothing and later ones share with it.
                ax1 = self.ax_heatmap.figure.add_subplot(
                    self.heatmap_gs[i, j],
                    sharex=self.heatmap_axes[0, j],
                    sharey=self.heatmap_axes[i, 0],
                )
                kwargs = self.kwargs.copy()
                ax2, axes = oncoprint(
                    self.data2d.loc[rows, cols],
                    colors=self.colors,
                    nvar=len(self.values),
                    aspect=self.aspect,
                    bgcolor=self.bgcolor,
                    row_gap=self.row_gap,
                    cmap=kwargs.pop("cmap", self.cmap),
                    ax=ax1,
                    subplot_spec=self.heatmap_gs[i, j],
                    **kwargs
                )
                self.heatmap_axes[i, j] = ax1
                self.heatmap_subaxes[i, j] = axes
                ax1.yaxis.label.set_visible(False)
                ax1.xaxis.label.set_visible(False)
                ax1.tick_params(
                    which="both",
                    left=False,
                    right=False,
                    labelleft=False,
                    labelright=False,
                    top=False,
                    bottom=False,
                    labeltop=False,
                    labelbottom=False,
                )

    def add_default_annotations(self):
        """
        Add bar plots of `self.col_vc` (top) and `self.row_vc` (right).

        A new HeatmapAnnotation is created when the side has none; otherwise
        the bar plot is added to the existing annotation (appended on top,
        inserted first on the right).
        """
        if self.top_annotation is None:
            self.top_annotation = HeatmapAnnotation(
                axis=1,
                Col=anno_barplot(self.col_vc, colors=self.colors, legend=False),
                verbose=0,
                label_side="left",
                label_kws={"horizontalalignment": "right"},
            )
        else:
            col_ann = anno_barplot(self.col_vc, colors=self.colors, legend=False)
            col_ann.set_label("Col")
            self.top_annotation.annotations.append(col_ann)
        if self.right_annotation is None:
            self.right_annotation = HeatmapAnnotation(
                axis=0,
                Row=anno_barplot(self.row_vc, colors=self.colors, legend=False),
                verbose=0,
                label_side="top",
                label_kws={"horizontalalignment": "right"},
            )
        else:
            row_ann = anno_barplot(self.row_vc, colors=self.colors, legend=False)
            row_ann.set_label("Row")
            # Keep right_annotation a HeatmapAnnotation: replacing it with a
            # plain list would break annotation.collect_legends() later on.
            self.right_annotation.annotations.insert(0, row_ann)

    def collect_legends(self):
        """
        Gather the legends of all annotations and of the matrix itself into
        `self.legend_list` and update `self.label_max_width`.
        """
        if self.verbose >= 1:
            print("Collecting legends..")
        self.legend_list = []
        self.label_max_width = 0
        for annotation in [
            self.top_annotation,
            self.bottom_annotation,
            self.left_annotation,
            self.right_annotation,
        ]:
            if annotation is None:
                continue
            annotation.collect_legends()
            if annotation.plot_legend and len(annotation.legend_list) > 0:
                self.legend_list.extend(annotation.legend_list)
            if annotation.label_max_width > self.label_max_width:
                self.label_max_width = annotation.label_max_width
        if self.legend:
            self.legend_list.append(
                [
                    self.color_dict,
                    self.label,
                    self.color_legend_kws,
                    len(self.color_dict),
                    "color_dict",
                ]
            )
        # NOTE(review): the nesting of the next two statements relative to
        # `if self.legend:` could not be recovered from the source; they are
        # kept outside of it - confirm.
        heatmap_label_max_width = (
            max([label.get_window_extent().width for label in self.yticklabels])
            if len(self.yticklabels) > 0 and self.row_names_side == "right"
            else 0
        )
        if (
            heatmap_label_max_width >= self.label_max_width
            or self.legend_anchor == "ax_heatmap"
        ):
            self.label_max_width = heatmap_label_max_width
        if len(self.legend_list) > 1:
            # Sort on the fourth field of every legend entry.
            self.legend_list = sorted(self.legend_list, key=lambda x: x[3])

    def post_processing(self):
        """Nothing to do after plotting."""
        pass
# Running this file directly is a no-op.
if __name__ == "__main__":
    pass