import seaborn as sns
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
from scipy.cluster.hierarchy import dendrogram as _dendrogram
import matplotlib as mpl
def straight_branch(ax, a, b, plot_kws=None):
    """Draw a right-angled link line between points ``a`` and ``b``.

    The line starts at ``a``, runs horizontally (at ``a``'s y) to ``b``'s x,
    then vertically to ``b``.

    Parameters
    ----------
    ax
        Object exposing ``plot(x, y, **kwargs)``; the line is drawn on it.
    a, b
        ``(x, y)`` pairs for the start and end points.
    plot_kws
        Keyword arguments forwarded to ``ax.plot``. ``None`` means no extras.

    Returns
    -------
    Whatever ``ax.plot`` returns.
    """
    start_x, start_y = a
    end_x, end_y = b
    if plot_kws is None:
        plot_kws = {}
    # Three vertices: start -> corner (end_x, start_y) -> end.
    xs = [start_x, end_x, end_x]
    ys = [start_y, start_y, end_y]
    return ax.plot(xs, ys, **plot_kws)
def _value_norm(data, limits):
    """Build a ``Normalize`` from explicit ``limits`` or from the range of ``data``."""
    if limits is None:
        limits = data.values
        if callable(limits):
            # Plain dict: ``values`` is a method rather than an array attribute.
            limits = list(limits())
    return Normalize(vmin=min(limits), vmax=max(limits))


def plot_dendrogram(
    linkage_df,
    ax,
    dendro=None,
    labels=None,
    dendro_kws=None,
    plot_node_id=False,
    plot_non_singleton=True,
    plot_kws=None,
    node_hue=None,
    node_hue_norm=None,
    node_hue_cbar=True,
    node_hue_cbar_frac=0.1,
    node_palette="viridis",
    node_size=None,
    node_size_norm=None,
    node_sizes=None,
    line_hue=None,
    line_hue_norm=None,
    line_palette="gray_r",
    linewidth=1.5,
    edge_color="gray",
    marker_size=60,
    marker_color="lightblue",
):
    """Draw a dendrogram on ``ax`` with per-node markers and per-branch colors.

    Node positions are recomputed from ``linkage_df`` (leaves at integer x
    positions in dendrogram leaf order, internal nodes at the midpoint of
    their two children and at the merge height) so that every node can be
    addressed by its node id.

    Parameters
    ----------
    linkage_df
        Table exposing ``iterrows()`` and ``shape``; each row is unpacked as
        ``(left, right, height, _)``. Presumably a DataFrame holding a scipy
        linkage matrix -- confirm against callers.
    ax
        Axes to draw on.
    dendro
        Precomputed dendrogram dict; its ``"leaves"`` and ``"ivl"`` entries
        are used. When ``None`` it is computed with scipy's ``dendrogram``
        using ``no_plot=True``.
    labels
        Leaf labels passed to scipy's ``dendrogram``. Required when ``dendro``
        is ``None``; replaced by ``dendro["ivl"]`` when ``dendro`` is given.
    dendro_kws
        Extra keyword arguments for scipy's ``dendrogram``; they override the
        default ``no_plot=True``.
    plot_node_id
        If true, write each node id as text (leaves at y = -0.01, internal
        nodes at their own position).
    plot_non_singleton
        If false, markers are drawn for leaves only. Branches are always drawn.
    plot_kws
        Only the ``"fontsize"`` key is read (node id text, default 6).
    node_hue
        Mapping from node id to a numeric value used for continuous marker
        color. Must support ``in`` and ``[]``; a ``.values`` attribute or
        method supplies the value range.
    node_hue_norm
        Iterable whose min and max set the color limits for ``node_hue``.
        Defaults to the range of ``node_hue``.
    node_hue_cbar
        If true (and ``node_hue`` is used), add a colorbar to the figure.
    node_hue_cbar_frac
        ``fraction`` argument of the colorbar.
    node_palette
        Colormap name for ``node_hue``, or a dict mapping node id directly to
        a color (in which case ``node_hue`` is ignored).
    node_size
        Mapping from node id to a numeric value used for marker size.
    node_size_norm
        Iterable whose min and max set the limits for ``node_size``.
        Defaults to the range of ``node_size``.
    node_sizes
        ``(smallest, largest)`` marker sizes. Defaults to
        ``(marker_size, 2 * marker_size)``.
    line_hue
        Mapping from node id to a numeric value; the branch leading down to a
        child is colored by the child's value.
    line_hue_norm
        Iterable whose min and max set the color limits for ``line_hue``.
    line_palette
        Colormap name for ``line_hue``.
    linewidth
        Branch line width.
    edge_color
        Branch color for nodes without a ``line_hue`` value
        (``"#D3D3D3"`` if ``None``).
    marker_size
        Marker size for nodes without a ``node_size`` value.
    marker_color
        Marker color for nodes without a color (``"#D3D3D3"`` if ``None``).

    Returns
    -------
    dendro
        The dendrogram dict that was used (given or computed).
    node_pos
        Dict mapping node id to its ``(x, y)`` position.

    Raises
    ------
    ValueError
        If ``dendro`` is ``None`` and ``labels`` or ``linkage_df`` is ``None``.
    """
    if plot_kws is None:
        plot_kws = {}

    if dendro is None:
        if labels is None or linkage_df is None:
            raise ValueError(
                "linkage_df and labels must be provided to calculate dendrogram."
            )
        print("Computing dendrogram")
        merged_dendro_kws = {"no_plot": True}
        if dendro_kws is not None:
            merged_dendro_kws.update(dendro_kws)
        # Only the leaf order is needed from scipy; node positions are
        # recalculated below so they can be keyed by node id.
        dendro = _dendrogram(linkage_df, labels=labels, **merged_dendro_kws)
    else:
        labels = dendro["ivl"]

    n_leaves = len(dendro["leaves"])

    # ------------------ Node positions ------------------
    node_pos = {}  # every node, singleton and non-singleton
    direct_link_map = {}  # non-singleton node id -> [left child, right child]
    for leaf_x, leaf in enumerate(dendro["leaves"]):
        node_pos[int(leaf)] = (leaf_x, 0)
    n_merges = linkage_df.shape[0]
    for row_no, (_, (left, right, height, _)) in enumerate(linkage_df.iterrows()):
        # Merge number ``row_no`` creates node id ``n_merges + 1 + row_no``.
        node_id = int(row_no + n_merges + 1)
        left = int(left)
        right = int(right)
        node_x = (node_pos[left][0] + node_pos[right][0]) / 2
        node_pos[node_id] = [node_x, height]
        direct_link_map[node_id] = [left, right]

    # ------------------ Node colors ------------------
    default_node_color = "#D3D3D3" if marker_color is None else marker_color
    if isinstance(node_palette, dict):
        # Categorical: the palette maps node ids straight to colors.
        node_colors = {
            node: node_palette.get(node, default_node_color) for node in node_pos
        }
    elif node_hue is not None:
        # Continuous: map values through a colormap.
        node_norm = _value_norm(node_hue, node_hue_norm)
        node_cmap = get_cmap(node_palette)
        if node_hue_cbar:
            ax.figure.colorbar(
                mpl.cm.ScalarMappable(norm=node_norm, cmap=node_cmap),
                ax=ax.figure.axes,
                shrink=0.6,
                fraction=node_hue_cbar_frac,
                label="Node Color",
            )
        # The RGBA tuple is wrapped in a 1-tuple so scatter receives it as a
        # sequence holding one color.
        node_colors = {
            node: (node_cmap(node_norm(node_hue[node])),)
            if node in node_hue
            else default_node_color
            for node in node_pos
        }
    else:
        node_colors = dict.fromkeys(node_pos, default_node_color)

    # ------------------ Node sizes ------------------
    if node_size is not None:
        size_range = (marker_size, marker_size * 2) if node_sizes is None else node_sizes
        smallest, largest = size_range[0], size_range[1]
        size_norm = _value_norm(node_size, node_size_norm)

        def _scaled_size(value):
            # Clamp the normalized value to [0, 1] before scaling.
            fraction = min(1, max(0, size_norm(value)))
            return fraction * (largest - smallest) + smallest

        marker_sizes = {
            node: _scaled_size(node_size[node]) if node in node_size else marker_size
            for node in node_pos
        }
    else:
        marker_sizes = dict.fromkeys(node_pos, marker_size)

    # ------------------ Plot nodes ------------------
    for node_id, (node_x, node_y) in node_pos.items():
        if node_id >= n_leaves and not plot_non_singleton:
            continue
        ax.scatter(
            node_x,
            node_y,
            s=marker_sizes[node_id],
            c=node_colors[node_id],
            clip_on=False,
            zorder=3,
        )

    # ------------------ Branch colors ------------------
    default_line_color = "#D3D3D3" if edge_color is None else edge_color
    if line_hue is not None:
        line_norm = _value_norm(line_hue, line_hue_norm)
        line_cmap = get_cmap(line_palette)
        line_colors = {
            node: line_cmap(line_norm(line_hue[node]))
            if node in line_hue
            else default_line_color
            for node in node_pos
        }
    else:
        line_colors = dict.fromkeys(node_pos, default_line_color)

    # ------------------ Plot node ids and branches ------------------
    fontsize = plot_kws.get("fontsize", 6)
    ymax = 0
    for node_id, (node_x, node_y) in node_pos.items():
        ymax = max(ymax, node_y)

        if plot_node_id:
            # Leaf ids sit just below the baseline; internal ids on the node.
            text_y = node_y if node_id >= n_leaves else -0.01
            ax.text(
                node_x,
                text_y,
                node_id,
                fontsize=fontsize,
                ha="center",
                va="center",
                c="k",
            )

        # Only non-singleton nodes have children; left is drawn before right.
        for child in direct_link_map.get(node_id, ()):
            straight_branch(
                ax,
                (node_x, node_y),
                node_pos[child],
                plot_kws=dict(
                    c=line_colors[child], linewidth=linewidth, clip_on=False
                ),
            )

    # ------------------ Axes cosmetics ------------------
    ax.set_ylim(0, ymax)
    ax.set_xlim(-0.5, len(labels) - 0.5)
    ax.set(xticks=range(len(labels)), xticklabels=dendro["ivl"])
    ax.xaxis.set_tick_params(rotation=90)
    sns.despine(ax=ax, bottom=True, offset=15)
    return dendro, node_pos