# Source code for PyComplexHeatmap.utils

# -*- coding: utf-8 -*-
# !/usr/bin/env python3
"""Utility functions, for internal use."""
import numpy as np
import pandas as pd
import collections
import matplotlib
import matplotlib.pylab as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.lines as mlines
import matplotlib.patches as mpatches

# Conversion factor from millimetres to inches (25.4 mm per inch).
mm2inch = 1.0 / 25.4


# =============================================================================
def set_default_style():
    """Apply this package's default matplotlib settings.

    Updates the global ``matplotlib.rcParams`` in place:

    - Type 42 (TrueType) fonts in PDF output.
    - Legend frames switched on, with a 10 pt legend font.
    - 100 dpi figures.
    - ``savefig`` uses a tight bounding box, 300 dpi and 0.05 inch padding.
    """
    from matplotlib import rcParams

    defaults = [
        # PDF font embedding type.
        ("pdf.fonttype", 42),
        # Legends keep their frame (frameon=True) and use a 10 pt font.
        ("legend.frameon", True),
        ("legend.fontsize", 10),
        # Figure resolution.
        ("figure.dpi", 100),
        # savefig: tight bounding box, output resolution and padding.
        ("savefig.bbox", "tight"),
        ("savefig.dpi", 300),
        ("savefig.pad_inches", 0.05),
    ]
    rcParams.update(dict(defaults))
# =============================================================================
def get_colormap(cmap):
    """Return a matplotlib Colormap for ``cmap``.

    Parameters
    ----------
    cmap : str or matplotlib.colors.Colormap
        A registered colormap name, or a Colormap instance.

    Returns
    -------
    matplotlib.colors.Colormap

    Raises
    ------
    ValueError
        Raised by ``plt.get_cmap`` when ``cmap`` is not a known colormap name.
    """
    try:
        # Newer matplotlib: ``plt.colormaps`` is the colormap registry. Its
        # ``get`` returns a copy of the map, or None for an unknown key.
        cmap_obj = plt.colormaps.get(cmap)
    except (AttributeError, TypeError):
        # AttributeError: older matplotlib, where ``plt.colormaps`` is a plain
        # function without ``get``. TypeError: ``cmap`` cannot be used as a
        # registry key (e.g. an unhashable Colormap instance).
        cmap_obj = None
    if cmap_obj is None:
        # pyplot passes Colormap instances through, maps None to the default
        # colormap and raises for unknown names instead of returning None.
        cmap_obj = plt.get_cmap(cmap)
    return cmap_obj
def _check_mask(data, mask): """ Ensure that data and mask are compatible and add missing values and infinite values. Values will be plotted for cells where ``mask`` is ``False``. ``data`` is expected to be a DataFrame; ``mask`` can be an array or a DataFrame. Parameters ---------- data mask Returns ------- """ if mask is None: mask = np.zeros(data.shape, bool) if isinstance(mask, np.ndarray): if mask.shape != data.shape: raise ValueError("Mask must have the same shape as data.") mask = pd.DataFrame(mask, index=data.index, columns=data.columns, dtype=bool) elif isinstance(mask, pd.DataFrame): if not mask.index.equals(data.index) and mask.columns.equals(data.columns): err = "Mask must have the same index and columns as data." raise ValueError(err) # Add any cells with missing values or infinite values to the mask mask = mask | pd.isnull(data) | np.logical_not(np.isfinite(data)) return mask # ============================================================================= def _calculate_luminance(color): """ Calculate the relative luminance of a color according to W3C standards Parameters ---------- color : matplotlib color or sequence of matplotlib colors Hex code, rgb-tuple, or html color name. Returns ------- luminance : float(s) between 0 and 1 """ rgb = matplotlib.colors.colorConverter.to_rgba_array(color)[:, :3] rgb = np.where(rgb <= 0.03928, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4) lum = rgb.dot([0.2126, 0.7152, 0.0722]) try: return lum.item() except ValueError: return lum # =============================================================================
def define_cmap(
    plot_data, vmin=None, vmax=None, cmap=None, center=None, robust=True, na_col="white"
):
    """
    Use some heuristics to set good defaults for colorbar and range.

    Parameters
    ----------
    plot_data : array-like
        Values to be colored. NaNs are ignored when inferring vmin/vmax.
    vmin, vmax : float, optional
        Data range covered by the colormap. When None they are taken from
        ``plot_data``: the 2nd / 98th percentiles if ``robust``, else the
        min / max.
    cmap : str, list/tuple of colors or matplotlib Colormap, optional
        Defaults to "jet", or to "exp1" when ``center`` is given.
        NOTE(review): "exp1" has to be registered with matplotlib elsewhere.
    center : float, optional
        Value that should receive the middle color of ``cmap``. The colormap
        is cut down to the part of the symmetric range
        ``center +/- max(vmax - center, center - vmin)`` that [vmin, vmax]
        actually spans.
    robust : bool
        Use percentiles instead of min/max when inferring vmin/vmax.
    na_col : color
        Color used for NaN ("bad") values.

    Returns
    -------
    cmap1 : matplotlib.colors.Colormap
        A new colormap object; a Colormap passed in is never modified.
    normalize : matplotlib.colors.Normalize
        Maps [vmin, vmax] onto the full range of ``cmap1``.

    Raises
    ------
    ValueError
        If ``cmap`` is a string that is not a registered colormap name.
    """
    if vmin is None:
        vmin = np.nanpercentile(plot_data, 2) if robust else np.nanmin(plot_data)
    if vmax is None:
        vmax = np.nanpercentile(plot_data, 98) if robust else np.nanmax(plot_data)

    # Choose default colormaps if not provided
    if cmap is None:
        cmap = "jet" if center is None else "exp1"

    if isinstance(cmap, str):
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; prefer the
        # colormap registry (matplotlib >= 3.5), which returns a copy.
        registry = getattr(matplotlib, "colormaps", None)
        if registry is None:
            cmap1 = matplotlib.cm.get_cmap(cmap).copy()
        else:
            try:
                cmap1 = registry[cmap]
            except KeyError as err:
                # Keep the ValueError that cm.get_cmap raised for unknown names.
                raise ValueError(f"{cmap!r} is not a known colormap name") from err
    elif isinstance(cmap, (list, tuple)):
        cmap1 = matplotlib.colors.ListedColormap(cmap)
    else:
        # Copy, so the set_bad/set_under/set_over calls below never alter a
        # Colormap object owned by the caller.
        cmap1 = cmap.copy()
    cmap1.set_bad(color=na_col)  # set the color for NaN values

    if center is not None:
        # Carry over under/over colors only when they were customized, i.e.
        # differ from the first/last regular color of the map.
        under = cmap1(-np.inf)
        over = cmap1(np.inf)
        under_set = under != cmap1(0)
        over_set = over != cmap1(cmap1.N - 1)
        # Where [vmin, vmax] falls inside the symmetric range around center;
        # that slice of the colormap is kept.
        vrange = max(vmax - center, center - vmin)
        symmetric = matplotlib.colors.Normalize(center - vrange, center + vrange)
        cmin, cmax = symmetric([vmin, vmax])
        cmap1 = matplotlib.colors.ListedColormap(cmap1(np.linspace(cmin, cmax, 256)))
        # The rebuilt colormap starts with default bad/under/over colors.
        cmap1.set_bad(color=na_col)
        if under_set:
            # the color of -np.inf becomes the color for low out-of-range values
            cmap1.set_under(under)
        if over_set:
            cmap1.set_over(over)

    # The (possibly cut) colormap spans exactly [vmin, vmax]. Pairing the cut
    # map with the symmetric center +/- vrange norm instead would apply the
    # centering twice and move `center` off the middle color whenever vmin
    # and vmax are not symmetric around it.
    normalize = matplotlib.colors.Normalize(vmin, vmax)
    return cmap1, normalize
# =============================================================================
def despine(fig=None, ax=None, top=True, right=True, left=False, bottom=False):
    """
    Hide selected spines of one or more axes.

    Parameters
    ----------
    fig : matplotlib figure, optional
        Every axes of this figure is processed. Takes precedence over ``ax``.
    ax : matplotlib axes, optional
        Single axes to process when ``fig`` is not given.
    top, right, left, bottom : bool, optional
        True hides that spine. By default the top and right spines are hidden.

    When neither ``fig`` nor ``ax`` is given, the axes of the current pyplot
    figure are used. If the left spine is hidden but the right one kept, the
    y ticks move to the right. If the bottom spine is hidden but the top one
    kept, the x ticks move to the top.

    Returns
    -------
    None
    """
    if fig is not None:
        targets = fig.axes
    elif ax is not None:
        targets = [ax]
    else:
        targets = plt.gcf().axes

    hidden = {"top": top, "right": right, "left": left, "bottom": bottom}

    def _move_ticks(axis, position):
        # Remember whether major/minor tick marks were shown on the primary
        # side, switch sides, then show the secondary tick lines accordingly.
        major_shown = any(tk.tick1line.get_visible() for tk in axis.majorTicks)
        minor_shown = any(tk.tick1line.get_visible() for tk in axis.minorTicks)
        axis.set_ticks_position(position)
        for tk in axis.majorTicks:
            tk.tick2line.set_visible(major_shown)
        for tk in axis.minorTicks:
            tk.tick2line.set_visible(minor_shown)

    for target in targets:
        for side, hide in hidden.items():
            target.spines[side].set_visible(not hide)
        if left and not right:
            _move_ticks(target.yaxis, "right")
        if bottom and not top:
            _move_ticks(target.xaxis, "top")
# ============================================================================= def _draw_figure(fig): """ Force draw of a matplotlib figure, accounting for back-compat. """ # See https://github.com/matplotlib/matplotlib/issues/19197 for context fig.canvas.draw() if fig.stale: try: fig.draw(fig.canvas.get_renderer()) except AttributeError: pass # =============================================================================
def axis_ticklabels_overlap(labels):
    """
    Report whether any tick labels overlap on screen.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : bool
        True when some label's window extent overlaps more than one box.
        False for an empty list, or when measuring raises RuntimeError.
    """
    if not labels:
        return False
    try:
        extents = [label.get_window_extent() for label in labels]
        # The threshold is 1, presumably because each box counts itself.
        most = max(box.count_overlaps(extents) for box in extents)
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False
    return most > 1
# ============================================================================= # ============================================================================= def _skip_ticks(labels, tickevery): """ Return ticks and labels at evenly spaced intervals. """ n = len(labels) if tickevery == 0: ticks, labels = [], [] elif tickevery == 1: ticks, labels = np.arange(n) + 0.5, labels else: start, end, step = 0, n, tickevery ticks = np.arange(start, end, step) + 0.5 labels = labels[start:end:step] return ticks, labels # ============================================================================= def _auto_ticks(ax, labels, axis): """ Determine ticks and ticklabels that minimize overlap. """ transform = ax.figure.dpi_scale_trans.inverted() bbox = ax.get_window_extent().transformed(transform) size = [bbox.width, bbox.height][axis] axis = [ax.xaxis, ax.yaxis][axis] (tick,) = axis.set_ticks([0]) fontsize = tick.label1.get_size() max_ticks = int(size // (fontsize / 72)) if max_ticks < 1: return [], [] tick_every = len(labels) // max_ticks + 1 tick_every = 1 if tick_every == 0 else tick_every ticks, labels = _skip_ticks(labels, tick_every) return ticks, labels # =============================================================================
def to_utf8(obj):
    """
    Return a string representing a Python object.

    ``str`` values are returned unchanged. Objects with a ``decode`` method
    (e.g. ``bytes``) are decoded as UTF-8. Anything else is converted with
    ``str()``.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        UTF-8-decoded string representation of ``obj``
    """
    if isinstance(obj, str):
        return obj
    try:
        text = obj.decode(encoding="utf-8")
    except AttributeError:  # obj is not bytes-like
        text = str(obj)
    return text
# ============================================================================= def _index_to_label(index): """ Convert a pandas index or multiindex to an axis label. """ if isinstance(index, pd.MultiIndex): return "-".join(map(to_utf8, index.names)) else: return index.name # ============================================================================= def _index_to_ticklabels(index): """ Convert a pandas index or multiindex into ticklabels. """ if isinstance(index, pd.MultiIndex): return ["-".join(map(to_utf8, i)) for i in index.values] else: return index.values # =============================================================================
def cluster_labels(labels=None, xticks=None, majority=True):
    """
    Merge runs of adjacent identical labels into one label per run.

    Parameters
    ----------
    labels : sequence of labels.
    xticks : sequence of x or y tick coordinates, parallel to ``labels``.
    majority : bool
        If True, keep only one run per distinct label: the longest one. The
        earliest run wins ties.

    Returns
    -------
    labels, ticks : lists
        The label of each kept run and the mean coordinate of that run, in
        the order the runs appear.

    Examples
    --------
    With labels = ['A','A','B','B','A','C','C','B','B','B','C'] and
    xticks = list(range(len(labels))):

    - ``cluster_labels(labels, xticks)`` gives ['A', 'C', 'B'] at 0.5, 5.5
      and 8.0.
    - With ``majority=False`` it gives every run: ['A', 'B', 'A', 'C', 'B', 'C'].
    """
    # A private sentinel can never equal a real label. The previous ""
    # placeholder made a leading "" label append to a run that was never
    # opened, raising KeyError later.
    start = object()
    previous = start
    run_labels = []  # label of each run, in order
    run_xs = []  # coordinates belonging to each run
    for label, x in zip(labels, xticks):
        if previous is start or label != previous:
            previous = label
            run_labels.append(label)
            run_xs.append([])
        run_xs[-1].append(x)

    if majority:
        # Index of the longest run per label; strict ">" keeps the earliest.
        best = {}
        for i, (label, xs) in enumerate(zip(run_labels, run_xs)):
            if label not in best or len(xs) > len(run_xs[best[label]]):
                best[label] = i
        keep = [i for i, label in enumerate(run_labels) if best[label] == i]
    else:
        keep = range(len(run_labels))

    new_labels = [run_labels[i] for i in keep]
    centers = [np.mean(run_xs[i]) for i in keep]
    return new_labels, centers
# =============================================================================
def plot_color_dict_legend(
    D=None, ax=None, title=None, color_text=True, label_side="right", kws=None
):
    """
    Plot a legend for a color dict.

    Parameters
    ----------
    D : dict
        Keys are categories (used as legend labels), values are colors.
    ax : matplotlib Axes, optional
        Axes to attach the legend to; defaults to ``plt.gca()``.
    title : str, optional
        Title of the legend.
    color_text : bool
        If True, each label is drawn in its category's color. Colors with
        relative luminance > 0.408 are drawn in black instead.
    label_side : {"right", "left"}
        Side of the color patch on which the label text is placed.
    kws : dict, optional
        Extra keyword arguments for ``ax.legend``. The dict is copied, never
        modified. "loc" and "bbox_transform" are always overridden with
        "upper left" and figure coordinates.

    Returns
    -------
    matplotlib.legend.Legend or None
        None when the legend is still taller than the space below its anchor
        once ``ncol`` reaches 3. The legend is removed in that case.
    """
    if ax is None:
        ax = plt.gca()
    lgd_kws = kws.copy() if kws is not None else {}
    lgd_kws.setdefault("frameon", True)
    lgd_kws.setdefault("ncol", 1)
    lgd_kws["loc"] = "upper left"
    lgd_kws["bbox_transform"] = ax.figure.transFigure
    # 0.1 mm expressed in points. NOTE: matplotlib reads borderpad in
    # font-size units.
    lgd_kws.setdefault("borderpad", 0.1 * mm2inch * 72)
    lgd_kws.setdefault("markerscale", 1)
    lgd_kws.setdefault("handleheight", 1)  # font-size units
    lgd_kws.setdefault("handlelength", 1)  # font-size units
    # Pad between the axes and the legend border, in font-size units.
    lgd_kws.setdefault("borderaxespad", 0.1)
    # Pad between the legend handle and text, in font-size units.
    lgd_kws.setdefault("handletextpad", 0.3)
    # Vertical gap between two entries, in font-size units.
    lgd_kws.setdefault("labelspacing", 0.1)
    lgd_kws.setdefault("columnspacing", 1)
    lgd_kws.setdefault("bbox_to_anchor", (0, 1))
    if label_side == "left":
        lgd_kws.setdefault("markerfirst", False)
        align = "right"
    else:
        lgd_kws.setdefault("markerfirst", True)
        align = "left"

    # Figure height (pixels) below the anchor. bbox_to_anchor is in figure
    # coordinates because of bbox_transform above.
    available_height = (
        ax.figure.get_window_extent().height * lgd_kws["bbox_to_anchor"][1]
    )
    handles = [mpatches.Patch(color=color, label=label) for label, color in D.items()]
    L = ax.legend(handles=handles, title=title, **lgd_kws)
    ax.figure.canvas.draw()  # draw so the legend extent can be measured
    # Too tall: spread the entries over more columns. Give up (remove the
    # legend and return None) once ncol reaches 3.
    while L.get_window_extent().height > available_height:
        print("Incresing ncol")
        lgd_kws["ncol"] += 1
        if lgd_kws["ncol"] >= 3:
            print("More than 3 cols is not supported")
            L.remove()
            return None
        L = ax.legend(handles=handles, title=title, **lgd_kws)
        ax.figure.canvas.draw()
    L._legend_box.align = align
    if color_text:
        # Legend texts hold str(label). Exact string keys win; non-string
        # keys (e.g. int categories) are matched through their str().
        text_colors = {str(k): v for k, v in D.items()}
        text_colors.update({k: v for k, v in D.items() if isinstance(k, str)})
        for text in L.get_texts():
            color = text_colors.get(text.get_text())
            if color is None:
                continue
            try:
                text_color = "black" if _calculate_luminance(color) > 0.408 else color
            except (ValueError, TypeError):
                continue  # not a single color matplotlib can convert
            text.set_color(text_color)
    ax.add_artist(L)
    ax.grid(False)
    return L
# =============================================================================
def plot_cmap_legend(
    cax=None, ax=None, cmap="turbo", label=None, kws=None, label_side="right"
):
    """
    Plot a colorbar legend for a colormap.

    Parameters
    ----------
    cax : matplotlib Axes
        Axes into which the colorbar will be drawn.
    ax : matplotlib Axes, optional
        Axes whose figure receives the colorbar. Defaults to ``cax``.
    cmap : str or Colormap
        e.g. turbo, hsv, Set1, Dark2, Paired, Accent, tab20, exp1, exp2,
        meth1, meth2.
    label : str, optional
        Title for the legend.
    kws : dict, optional
        Keyword arguments for ``Figure.colorbar``. The dict is copied, never
        modified. The extra keys "vmin" (default 0), "vmax" (default 1) and
        "center" (default None) are consumed here to build the norm.
    label_side : {"right", "left"}
        Side of ``cax`` carrying the label and the ticks.

    Returns
    -------
    cbar : matplotlib.colorbar.Colorbar
    """
    if ax is None:
        ax = cax
    label = "" if label is None else label
    cbar_kws = {} if kws is None else kws.copy()
    cbar_kws.setdefault("label", label)
    cbar_kws.setdefault("orientation", "vertical")
    cbar_kws.setdefault("fraction", 1)
    cbar_kws.setdefault("shrink", 1)
    cbar_kws.setdefault("pad", 0)
    vmax = cbar_kws.pop("vmax", 1)
    vmin = cbar_kws.pop("vmin", 0)
    center = cbar_kws.pop("center", None)
    cax.set_ylim([vmin, vmax])
    if center is None:
        cbar_kws.setdefault("ticks", [vmin, (vmax + vmin) / 2, vmax])
        norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    else:
        # Same symmetric range as define_cmap: center +/- the larger distance
        # from center to vmin or vmax. This ScalarMappable carries no data,
        # so without an explicit halfrange the CenteredNorm range would not
        # reflect vmin/vmax at all.
        halfrange = max(vmax - center, center - vmin)
        norm = matplotlib.colors.CenteredNorm(vcenter=center, halfrange=halfrange)
    m = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    cax.yaxis.set_label_position(label_side)
    cax.yaxis.set_ticks_position(label_side)
    cbar = ax.figure.colorbar(m, cax=cax, **cbar_kws)
    return cbar
# =============================================================================
def plot_marker_legend(
    obj=None, ax=None, title=None, color_text=True, label_side="right", kws=None
):
    """
    Plot a legend for categorical markers.

    Parameters
    ----------
    obj : tuple (markers, colors, ms)
        markers : dict mapping each category (legend label) to a marker.
        colors : None (all black), a single color applied to every category,
            or a dict mapping categories to colors. Categories missing from
            the dict are drawn in black.
        ms : None, a single marker size, or a dict of sizes per category.
            Categories missing from the dict get 10. None means
            ``kws["markersize"]``, default 10.
    ax : matplotlib Axes, optional
        Axes to attach the legend to; defaults to ``plt.gca()``.
    title : str, optional
        Title of the legend.
    color_text : bool
        If True, each label is drawn in its category's color. Colors with
        relative luminance > 0.408 are drawn in black instead.
    label_side : {"right", "left"}
        Side of the marker on which the label text is placed.
    kws : dict, optional
        Keyword arguments for ``ax.legend``, plus an optional "markersize".
        The dict is copied, never modified. "loc" and "bbox_transform" are
        always overridden with "upper left" and figure coordinates.

    Returns
    -------
    matplotlib.legend.Legend or None
        None when the legend is still taller than the space below its anchor
        once ``ncol`` reaches 3. The legend is removed in that case.
    """
    if ax is None:
        ax = plt.gca()
    markers, colors, ms = obj
    if colors is None:
        colors = "black"
    if isinstance(colors, dict):
        color_dict = colors
    else:
        # One color for every category. Only str used to be handled here, so
        # e.g. an RGB tuple left color_dict undefined.
        color_dict = {k: colors for k in markers}

    lgd_kws = kws.copy() if kws is not None else {}
    # "markersize" is not an ax.legend() argument: remove it up front. It is
    # the default size when ms is None.
    default_ms = lgd_kws.pop("markersize", 10)
    lgd_kws.setdefault("frameon", True)
    lgd_kws.setdefault("ncol", 1)
    lgd_kws["loc"] = "upper left"
    lgd_kws["bbox_transform"] = ax.figure.transFigure
    # 0.1 mm expressed in points. NOTE: matplotlib reads borderpad in
    # font-size units.
    lgd_kws.setdefault("borderpad", 0.1 * mm2inch * 72)
    if ms is None:
        ms_dict = {k: default_ms for k in markers}
    elif isinstance(ms, dict):
        ms_dict = ms
    else:
        ms_dict = {k: ms for k in markers}
    lgd_kws.setdefault("markerscale", 1)
    lgd_kws.setdefault("handleheight", 1)  # font-size units
    lgd_kws.setdefault("handlelength", 1)  # font-size units
    # Pad between the axes and the legend border, in font-size units.
    lgd_kws.setdefault("borderaxespad", 0.1)
    # Pad between the legend handle and text, in font-size units.
    lgd_kws.setdefault("handletextpad", 0.3)
    # Vertical gap between two entries, in font-size units.
    lgd_kws.setdefault("labelspacing", 0.5)
    lgd_kws.setdefault("columnspacing", 1)
    lgd_kws.setdefault("bbox_to_anchor", (0, 1))
    if label_side == "left":
        lgd_kws.setdefault("markerfirst", False)
        align = "right"
    else:
        lgd_kws.setdefault("markerfirst", True)
        align = "left"

    # Figure height (pixels) below the anchor. bbox_to_anchor is in figure
    # coordinates because of bbox_transform above.
    available_height = (
        ax.figure.get_window_extent().height * lgd_kws["bbox_to_anchor"][1]
    )
    handles = [
        mlines.Line2D(
            [],
            [],
            color=color_dict.get(label, "black"),
            marker=marker,
            linestyle="None",
            markersize=ms_dict.get(label, 10),
            label=label,
        )
        for label, marker in markers.items()
    ]
    Lgd = ax.legend(handles=handles, title=title, **lgd_kws)
    ax.figure.canvas.draw()  # draw so the legend extent can be measured
    # Too tall: spread the entries over more columns. Give up (remove the
    # legend and return None) once ncol reaches 3.
    while Lgd.get_window_extent().height > available_height:
        print("Incresing ncol")
        lgd_kws["ncol"] += 1
        if lgd_kws["ncol"] >= 3:
            print("More than 3 cols is not supported")
            Lgd.remove()
            return None
        Lgd = ax.legend(handles=handles, title=title, **lgd_kws)
        ax.figure.canvas.draw()
    Lgd._legend_box.align = align
    if color_text:
        # Legend texts hold str(label). Exact string keys win; non-string
        # keys (e.g. int categories) are matched through their str().
        text_colors = {str(k): v for k, v in color_dict.items()}
        text_colors.update({k: v for k, v in color_dict.items() if isinstance(k, str)})
        for text in Lgd.get_texts():
            color = text_colors.get(text.get_text())
            if color is None:
                continue
            try:
                text_color = "black" if _calculate_luminance(color) > 0.408 else color
            except (ValueError, TypeError):
                continue  # not a single color matplotlib can convert
            text.set_color(text_color)
    ax.add_artist(Lgd)
    ax.grid(False)
    return Lgd
# =============================================================================
def cal_legend_width(legend_list):
    """
    Estimate the width in mm needed by the widest legend in ``legend_list``.

    Parameters
    ----------
    legend_list : iterable
        Items ``[obj, title, legend_kws, height, legend_type]``:

        - "color_dict": uses the longest of the title and the dict keys.
        - "markers": uses only the title.
        - Any other type (e.g. "cmap"): counts as the default colorbar width
          of 4.5 mm.

        Text width is estimated from ``legend_kws["fontsize"]``, falling
        back to rcParams["legend.fontsize"], or 10 pt when either is not
        numeric.

    Returns
    -------
    float
        Estimated width in mm (0 for an empty list).
    """
    import numbers

    default_width = 4.5  # mm, width assumed for colorbar legends
    rc_size = plt.rcParams["legend.fontsize"]
    # rcParams may hold a named size such as "medium". Only numbers can be
    # used in the arithmetic below, so fall back to 10 pt.
    lgd_fontsize = rc_size if isinstance(rc_size, numbers.Real) else 10

    def _points(size):
        # Same fallback for a per-legend "fontsize" given by name.
        return size if isinstance(size, numbers.Real) else lgd_fontsize

    legend_width = 0
    for obj, title, legend_kws, _height, lgd_t in legend_list:
        if lgd_t == "color_dict":
            # Longest of the title and the category labels; an empty dict
            # falls back to the title alone.
            max_text_len = max([len(str(title))] + [len(str(k)) for k in obj])
        elif lgd_t == "markers":
            max_text_len = len(str(title))
        else:
            legend_width = max(legend_width, default_width)
            continue
        fontsize = _points(legend_kws.get("fontsize", lgd_fontsize))
        # characters * font size (pt) * ~0.65 average width/height ratio,
        # then points -> inches -> mm.
        lgd_w = max_text_len * fontsize * 0.65 / 72 / mm2inch
        legend_width = max(legend_width, lgd_w)
    return legend_width
def plot_legend_list(
    legend_list=None,
    ax=None,
    space=0,
    legend_side="right",
    y0=None,
    gap=2,
    delta_x=None,
    legend_width=None,
    legend_vpad=5,
    cmap_width=4.5,
    verbose=1
):
    """
    Plot all legends in ``legend_list``, stacked from top to bottom in one or
    more columns of legend axes placed next to ``ax``.

    Parameters
    ----------
    legend_list : list
        Items ``[obj, title, legend_kws, height, legend_type]``.
        ``legend_type`` is one of:

        - "cmap": obj is a colormap, drawn with plot_cmap_legend.
        - "color_dict": obj is a dict of colors, drawn with
          plot_color_dict_legend.
        - "markers": obj is a (markers, colors, ms) tuple, drawn with
          plot_marker_legend.

        The fourth item is unpacked but not used here.
        NOTE(review): each ``legend_kws`` dict is modified in place
        ("color_text" is popped, "bbox_to_anchor" is set), so reusing the
        same list for a second call loses ``color_text``.
    ax : matplotlib Axes, optional
        Axes to place the legends next to. When None, ``plt.gca()`` is used,
        its axis is switched off, and the legends start near its left edge.
    space : float
        Extra horizontal space before the legends, in pixels.
    legend_side : str
        Only "right" creates the legend axes.
        NOTE(review): with any other value ``ax_legend`` is never defined and
        the call fails with NameError.
    y0 : float, optional
        Top y (figure fraction) of the first legend in each column. Defaults
        to ``legend_vpad`` mm below the top of the legend axes.
    gap : float
        Vertical gap between consecutive legends, in mm (default 2).
    delta_x : float, optional
        Horizontal offset (figure fraction) used instead of the
        labelpad-based pad.
    legend_width : float, optional
        Width of a legend column in mm. When None it is estimated as
        cal_legend_width(...) + 3 mm, or 15 mm if the estimate fails.
    legend_vpad : float
        Top padding of each legend column, in mm.
    cmap_width : float
        Width of colorbar legends, in mm.
    verbose : int
        Print the estimated legend width when > 0.

    Returns
    -------
    legend_axes : list of Axes
        One axes per legend column.
    cbars : list
        The colorbars / legends that were drawn, in order.
    boundry : float
        Edge of the last column plus the width of its widest legend (figure
        fraction).
        NOTE(review): this adds a width fraction to ``y1``/``y0`` (vertical
        coordinates); ``x0``/``x1`` looks intended — confirm against callers.
        An empty ``legend_list`` leaves ``ax1`` undefined here.
    """
    if ax is None:
        print("No ax was provided, using plt.gca()")
        ax = plt.gca()
        ax.set_axis_off()
        # Start at the left edge of ax, offset by 2 * labelpad / figure width,
        # unless delta_x is given.
        left = (
            ax.get_position().x0
            + ax.yaxis.labelpad * 2 / ax.figure.get_window_extent().width
            if delta_x is None
            else ax.get_position().x0 + delta_x
        )
    else:
        # labelpad: Spacing in points, pad is the fraction relative to x1.
        pad = (
            (space + ax.yaxis.labelpad * 1.2 * ax.figure.dpi / 72)
            / ax.figure.get_window_extent().width
            if delta_x is None
            else delta_x
        )  # labelpad unit is points
        left = ax.get_position().x1 + pad
    if legend_width is None:
        try:
            legend_width = (
                cal_legend_width(legend_list) + 3
            )  # base width for color rectangle is set to 3 mm
            if verbose > 0:
                print(f"Estimated legend width: {legend_width} mm")
        except:  # NOTE(review): bare except also hides unrelated errors
            legend_width=15
    legend_width = (
        legend_width * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().width
    )  # mm to px to fraction
    cmap_width = (
        cmap_width * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().width
    )  # mm to px to fraction
    if legend_side == "right":
        ax_legend = ax.figure.add_axes(
            [left, ax.get_position().y0, legend_width, ax.get_position().height]
        )  # left, bottom, width, height
    legend_axes = [ax_legend]
    cbars = []
    leg_pos = ax_legend.get_position()  # left bottom: x0,y0; top right: x1,y1
    # y is the top (figure fraction) of the next legend; legends are stacked
    # downwards from here.
    y = (
        leg_pos.y1
        - legend_vpad * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().height
        if y0 is None
        else y0
    )
    lgd_col_max_width = 0  # the maximum width (pixels) of all legends in one column
    v_gap = round(
        gap * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().height, 2
    )  # gap (mm) between two legends as a figure-height fraction, rounded to 2 decimals
    i = 0
    while i <= len(legend_list) - 1:
        obj, title, legend_kws, n, lgd_t = legend_list[i]
        ax1 = legend_axes[-1]  # current (right-most) legend column
        ax1.set_axis_off()
        color_text = legend_kws.pop("color_text", True)
        if lgd_t == "cmap":  # type(obj)==str: # a cmap, plot colorbar
            f = (
                15 * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().height
            )  # colorbar height: 15 mm as a figure-height fraction
            if y - f < 0:  # add a new column of axes to plot legends
                offset = (
                    lgd_col_max_width + ax.yaxis.labelpad * 2
                ) / ax.figure.get_window_extent().width
                ax2 = ax.figure.add_axes(
                    rect=[
                        ax1.get_position().x0 + offset,
                        ax.get_position().y0,
                        cmap_width,
                        ax.get_position().height,
                    ]
                )  # left_pos.width
                legend_axes.append(ax2)
                ax1 = legend_axes[-1]
                ax1.set_axis_off()
                leg_pos = ax1.get_position()
                # restart at the top of the new column
                y = (
                    leg_pos.y1
                    - legend_vpad * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().height
                    if y0 is None
                    else y0
                )
                lgd_col_max_width = 0
            cax = ax1.figure.add_axes(
                rect=[leg_pos.x0, y - f, cmap_width, f], xmargin=0, ymargin=0
            )  # unit is fractions of figure width and height
            # [i.set_linewidth(0.5) for i in cax.spines.values()]
            # NOTE(review): Figure.subplots_adjust changes the figure's subplot
            # parameters, not only cax — confirm this side effect is intended.
            cax.figure.subplots_adjust(bottom=0)  # wspace=0, hspace=0
            # https://matplotlib.org/stable/api/figure_api.html
            # [left, bottom, width, height],sharex=True,anchor=(0,0),frame_on=False.
            cbar = plot_cmap_legend(
                ax=ax1,
                cax=cax,
                cmap=obj,
                label=title,
                label_side=legend_side,
                kws=legend_kws,
            )
            cbar_width = cbar.ax.get_window_extent().width
            cbars.append(cbar)
            if cbar_width > lgd_col_max_width:
                lgd_col_max_width = cbar_width
        elif lgd_t == "color_dict":
            # print(obj, title, legend_kws)
            legend_kws["bbox_to_anchor"] = (
                leg_pos.x0,
                y,
            )  # anchor point (figure fraction) for the legend box
            # x, y, width, height
            #kws['bbox_transform'] = ax.figure.transFigure
            # ax1.scatter(leg_pos.x0,y,s=6,color='red',zorder=20,transform=ax1.figure.transFigure)
            L = plot_color_dict_legend(
                D=obj,
                ax=ax1,
                title=title,
                label_side=legend_side,
                color_text=color_text,
                kws=legend_kws,
            )
            if L is None:
                # Did not fit below y: open a new column to the right and retry.
                print("Legend too long, generating a new column..")
                pad = (
                    lgd_col_max_width + ax.yaxis.labelpad * 2
                ) / ax.figure.get_window_extent().width
                left_pos = ax1.get_position()
                ax2 = ax.figure.add_axes(
                    [
                        left_pos.x0 + pad,
                        ax.get_position().y0,
                        left_pos.width,
                        ax.get_position().height,
                    ]
                )
                legend_axes.append(ax2)
                ax1 = legend_axes[-1]
                ax1.set_axis_off()
                leg_pos = ax1.get_position()
                y = (
                    leg_pos.y1
                    - legend_vpad * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().height
                    if y0 is None
                    else y0
                )
                legend_kws["bbox_to_anchor"] = (leg_pos.x0, y)
                # NOTE(review): if it still does not fit, L is None again and
                # L.get_window_extent() below raises AttributeError.
                L = plot_color_dict_legend(
                    D=obj,
                    ax=ax1,
                    title=title,
                    label_side=legend_side,
                    color_text=color_text,
                    kws=legend_kws,
                )
                lgd_col_max_width = 0
            L_width = L.get_window_extent().width
            if L_width > lgd_col_max_width:
                lgd_col_max_width = L_width
            # legend height as a figure-height fraction
            f = L.get_window_extent().height / ax.figure.get_window_extent().height
            cbars.append(L)
        elif lgd_t == "markers":
            legend_kws["bbox_to_anchor"] = (
                leg_pos.x0,
                y,
            )  # anchor point (figure fraction) for the legend box
            L = plot_marker_legend(
                obj=obj,
                ax=ax1,
                title=title,
                label_side=legend_side,
                color_text=color_text,
                kws=legend_kws,
            )  # obj is a tuple: markers and colors
            if L is None:
                # Did not fit below y: open a new column to the right and retry.
                print("Legend too long, generating a new column..")
                pad = (
                    lgd_col_max_width + ax.yaxis.labelpad * 2
                ) / ax.figure.get_window_extent().width
                left_pos = ax1.get_position()
                ax2 = ax.figure.add_axes(
                    [
                        left_pos.x0 + pad,
                        ax.get_position().y0,
                        left_pos.width,
                        ax.get_position().height,
                    ]
                )
                legend_axes.append(ax2)
                ax1 = legend_axes[-1]
                ax1.set_axis_off()
                leg_pos = ax1.get_position()
                y = (
                    leg_pos.y1
                    - legend_vpad * mm2inch * ax.figure.dpi / ax.figure.get_window_extent().height
                    if y0 is None
                    else y0
                )
                legend_kws["bbox_to_anchor"] = (leg_pos.x0, y)
                # NOTE(review): a second None result is not handled either.
                L = plot_marker_legend(
                    obj=obj,
                    ax=ax1,
                    title=title,
                    label_side=legend_side,
                    color_text=color_text,
                    kws=legend_kws,
                )
                lgd_col_max_width = 0
            L_width = L.get_window_extent().width
            if L_width > lgd_col_max_width:
                lgd_col_max_width = L_width
            # legend height as a figure-height fraction
            f = L.get_window_extent().height / ax.figure.get_window_extent().height
            cbars.append(L)
        # Move down past this legend plus the gap.
        # NOTE(review): f is only set by the three known legend types; an
        # unknown lgd_t reuses the previous f (NameError on the first item).
        y = y - f - v_gap
        i += 1
    if legend_side == "right":
        boundry = (
            ax1.get_position().y1
            + lgd_col_max_width / ax.figure.get_window_extent().width
        )
    else:
        boundry = (
            ax1.get_position().y0
            - lgd_col_max_width / ax.figure.get_window_extent().width
        )
    return legend_axes, cbars, boundry
# ============================================================================= set_default_style()