import numpy as np
import seaborn as sns
from matplotlib.cm import get_cmap, ScalarMappable
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import copy
import anndata
import pandas as pd
from .color import plot_colorbar
from .contour import density_contour
from .text_anno_scatter import _text_anno_scatter
from .utilities import (
_density_based_sample,
_extract_coords,
_make_tiny_axis_label,
zoom_ax,
)
def tight_hue_range(hue_data, portion):
    """Automatically select the SMALLEST data range that covers ``portion`` of the data.

    The finite values of ``hue_data`` are summarised by their quantiles on a
    1 % grid (q = 0.00, 0.01, ..., 0.99). A rolling window of
    ``round(portion * 100)`` consecutive grid points is slid over those
    quantiles, and the window with the smallest value spread is selected.
    Its right end ``q_right`` and ``q_left = max(0, q_right - portion)``
    define the returned range.

    Parameters
    ----------
    hue_data : array-like of numbers
        Values to summarise (pandas Series, numpy array, list, ...).
        NaN and +/-inf are ignored.
    portion : float
        Fraction of the data the range should cover, e.g. 0.95. Values
        above 1 behave like 1 (the window spans the whole quantile grid).

    Returns
    -------
    tuple
        ``(vmin, vmax)``. If the selected range collapses to a single value,
        the full ``(min, max)`` of the finite data is returned instead.

    Raises
    ------
    ValueError
        If ``portion`` is not positive or ``hue_data`` has no finite values.
    """
    if not portion > 0:
        raise ValueError(f"portion must be positive, got {portion!r}")
    values = pd.Series(hue_data).astype(float)
    values = values[np.isfinite(values)]
    if values.empty:
        raise ValueError("hue_data contains no finite values")

    # round() instead of int(): int(0.57 * 100) == 56 because of float error.
    # Clamp so the rolling window always fits the 100-point quantile grid
    # (a 0-width or >100-wide window yields only NaN and idxmin fails).
    window = min(max(int(round(portion * 100)), 1), 100)
    hue_quantile = values.quantile(q=np.arange(0, 1, 0.01))
    # Quantiles are non-decreasing, so max - min of a window is its spread;
    # idxmin returns the q label at the right end of the tightest window.
    min_window_right = (
        hue_quantile.rolling(window=window)
        .apply(lambda i: i.max() - i.min(), raw=True)
        .idxmin()
    )
    min_window_left = max(0, min_window_right - portion)
    vmin, vmax = values.quantile(q=[min_window_left, min_window_right])

    # Quantiles of non-empty finite data are finite; just keep them in range.
    data_min, data_max = values.min(), values.max()
    vmin = max(data_min, vmin)
    vmax = min(data_max, vmax)
    if vmin == vmax:
        # Degenerate range (e.g. mostly-constant data): fall back to full range.
        return data_min, data_max
    return vmin, vmax
def continuous_scatter(
    data,
    ax,
    coord_base="umap",
    x=None,
    y=None,
    scatter_kws=None,
    hue=None,
    hue_norm=None,
    hue_portion=0.95,
    cmap="viridis",
    colorbar=True,
    colorbar_label_kws=None,
    size=None,
    size_norm=None,
    size_portion=0.95,
    sizes=None,
    sizebar=True,
    text_anno=None,
    dodge_text=False,
    dodge_kws=None,
    text_anno_kws=None,
    text_anno_palette=None,
    text_transform=None,
    axis_format="tiny",
    max_points=5000,
    s=5,
    labelsize=4,
    linewidth=0.5,
    cax=None,
    zoomxy=1.05,
    outline=None,
    outline_kws=None,
    outline_pad=2,
    return_fig=False,
    rasterized="auto",
):
    """Scatter-plot 2-D coordinates, optionally coloured/sized by continuous values.

    Parameters
    ----------
    data : anndata.AnnData or pandas.DataFrame
        For AnnData, coordinates come from the first two columns of
        ``adata.obsm[f"X_{coord_base}"]`` and annotations from ``adata.obs``.
        Otherwise coordinates are resolved by
        ``_extract_coords(data, coord_base, x, y)``.
    ax : matplotlib Axes to draw on.
    coord_base, x, y : how the coordinates are located (see ``data``).
        For AnnData, ``x``/``y`` are overwritten with
        ``f"{coord_base}_0"`` / ``f"{coord_base}_1"``; they are used as axis labels.
    scatter_kws : dict, optional
        Extra keyword arguments for ``seaborn.scatterplot``. These override
        the defaults ``linewidth=0``, ``s=s``, ``legend=None`` and ``rasterized``.
    hue : str or array-like, optional
        Column of ``data`` (or values) converted to float and used for colour.
    hue_norm : (vmin, vmax), optional
        Colour range. If None, it is computed by ``tight_hue_range`` from
        ``hue_portion``.
    cmap : str, matplotlib Colormap, or ScalarMappable
        For a str, NaN values are drawn in semi-transparent grey. For a
        ScalarMappable, its ``.cmap`` is used.
    colorbar : bool
        If true and ``hue`` is given, draw a colour bar into ``cax``. If
        ``cax`` is None, an inset axes is created at the lower right of ``ax``.
    colorbar_label_kws : dict, optional
        If given, applied to the colour-bar label after it is drawn.
        Defaults are ``fontsize=labelsize``, ``labelpad=10``, ``rotation=270``;
        a ``"label"`` key overrides the label text.
    size, size_norm, size_portion, sizes : marker-size mapping
        These work like the ``hue`` options. If ``sizes`` is None, it
        defaults to ``(min(s, 1), s)``.
    sizebar : bool
        Not implemented yet.
    text_anno, dodge_text, dodge_kws, text_anno_kws, text_anno_palette, text_transform :
        Text annotation options. ``text_anno`` is a column of ``data`` or
        values; all options are forwarded to ``_text_anno_scatter``.
    axis_format : {"tiny", "empty", None} or other
        ``"tiny"`` draws small axis labels. ``"empty"``/None removes spines,
        ticks and labels. Any other value leaves the axes untouched.
    max_points : int or None
        If there are more rows than this, rows are downsampled by
        ``_density_based_sample`` (seed 1).
    s : marker size.
    labelsize : font size for axis, colour-bar and annotation labels.
    linewidth : line width of the outline contours.
    cax : Axes for the colour bar, optional.
    zoomxy : float or None
        If not None, passed to ``zoom_ax``.
    outline, outline_kws, outline_pad : outline contour options
        ``outline`` is a column of ``data`` or values, used as the
        ``groupby`` for ``density_contour``.
    return_fig : bool
        Whether to return the axes and the plotted data.
    rasterized : bool or "auto"
        ``"auto"`` rasterizes when more than 200 points are plotted
        (counted after downsampling).

    Returns
    -------
    (tuple of Axes, pandas.DataFrame) if ``return_fig`` else None
    """
    from matplotlib.colors import Colormap

    if isinstance(data, anndata.AnnData):
        adata = data
        data = adata.obs
        x = f"{coord_base}_0"
        y = f"{coord_base}_1"
        coords = adata.obsm[f"X_{coord_base}"]
        _data = pd.DataFrame(
            {"x": coords[:, 0], "y": coords[:, 1]},
            index=adata.obs_names,
        )
    else:
        # _data has 2 cols: "x" and "y"
        _data, x, y = _extract_coords(data, coord_base, x, y)

    # down sample plot data if needed.
    if max_points is not None and _data.shape[0] > max_points:
        _data = _density_based_sample(
            _data, seed=1, size=max_points, coords=["x", "y"]
        )

    if rasterized == "auto":
        rasterized = _data.shape[0] > 200

    # default scatter options
    _scatter_kws = {"linewidth": 0, "s": s, "legend": None, "rasterized": rasterized}
    if scatter_kws is not None:
        _scatter_kws.update(scatter_kws)

    # deal with color
    if hue is not None:
        if isinstance(hue, str):
            _data["hue"] = data[hue].astype(float)
            colorbar_label = hue
        else:
            _data["hue"] = hue.astype(float)
            # numpy arrays have no .name attribute
            colorbar_label = getattr(hue, "name", None)
        hue = "hue"
        if hue_norm is None:
            # get the smallest range that include "hue_portion" of data
            hue_norm = tight_hue_range(_data["hue"], hue_portion)
        # cnorm is the normalizer for color
        cnorm = Normalize(vmin=hue_norm[0], vmax=hue_norm[1])
        if isinstance(cmap, str):
            # copy so set_bad does not modify the registered colormap
            cmap = copy.copy(get_cmap(cmap))
            cmap.set_bad(color=(0.5, 0.5, 0.5, 0.5))
        elif isinstance(cmap, ScalarMappable):
            # backward compatibility: use the colormap held by the mappable
            cmap = cmap.cmap
        elif not isinstance(cmap, Colormap):
            raise TypeError(
                f"cmap can only be str, Colormap or ScalarMappable, got {type(cmap)}"
            )
    else:
        hue_norm = None
        cnorm = None
        colorbar_label = ""

    # deal with size
    if size is not None:
        if isinstance(size, str):
            _data["size"] = data[size].astype(float)
        else:
            _data["size"] = size.astype(float)
        size = "size"
        if size_norm is None:
            # get the smallest range that include "size_portion" of data
            size_norm = tight_hue_range(_data["size"], size_portion)
        size_norm = Normalize(vmin=size_norm[0], vmax=size_norm[1])
        # marker size is now controlled by `sizes` instead of `s`
        s = _scatter_kws.pop("s")
        if sizes is None:
            sizes = (min(s, 1), s)
    else:
        size_norm = None
        sizes = None

    sns.scatterplot(
        x="x",
        y="y",
        data=_data,
        hue=hue,
        palette=cmap,
        hue_norm=cnorm,
        size=size,
        sizes=sizes,
        size_norm=size_norm,
        ax=ax,
        **_scatter_kws,
    )

    if text_anno is not None:
        if isinstance(text_anno, str):
            _data["text_anno"] = data[text_anno]
        else:
            _data["text_anno"] = text_anno
        _text_anno_scatter(
            data=_data[["x", "y", "text_anno"]],
            ax=ax,
            x="x",
            y="y",
            dodge_text=dodge_text,
            dodge_kws=dodge_kws,
            palette=text_anno_palette,
            text_transform=text_transform,
            anno_col="text_anno",
            text_anno_kws=text_anno_kws,
            labelsize=labelsize,
        )

    # deal with outline. `if outline:` raised for Series/arrays (ambiguous
    # truth value); scalars (None, "", str, bool) keep their old truthiness.
    if outline is None or np.isscalar(outline):
        draw_outline = bool(outline)
    else:
        draw_outline = True
    if draw_outline:
        if isinstance(outline, str):
            _data["outline"] = data[outline]
        else:
            _data["outline"] = outline
        _outline_kws = {
            "linewidth": linewidth,
            "palette": None,
            "c": "lightgray",
            "single_contour_pad": outline_pad,
        }
        if outline_kws is not None:
            _outline_kws.update(outline_kws)
        density_contour(
            ax=ax, data=_data, x="x", y="y", groupby="outline", **_outline_kws
        )

    # clean axis
    if axis_format == "tiny":
        _make_tiny_axis_label(ax, x, y, arrow_kws=None, fontsize=labelsize)
    elif (axis_format == "empty") or (axis_format is None):
        sns.despine(ax=ax, left=True, bottom=True)
        ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)

    return_axes = [ax]

    # make color bar
    if colorbar and (hue is not None):
        # small ax for colorbar
        if cax is None:
            cax = inset_axes(
                ax, width="3%", height="25%", loc="lower right", borderpad=0
            )
        cax = plot_colorbar(
            cax=cax,
            cmap=cmap,
            cnorm=cnorm,
            hue_norm=hue_norm,
            label=colorbar_label,
            orientation="vertical",
            labelsize=labelsize,
            linewidth=0.5,
        )
        if colorbar_label_kws is not None:
            # plot_colorbar takes no label kwargs, so style the (vertical)
            # label afterwards. Assumes plot_colorbar returns the colour-bar
            # Axes, as it is returned to callers as such -- TODO confirm.
            _colorbar_label_kws = dict(fontsize=labelsize, labelpad=10, rotation=270)
            _colorbar_label_kws.update(colorbar_label_kws)
            label_text = _colorbar_label_kws.pop("label", colorbar_label)
            cax.set_ylabel(label_text, **_colorbar_label_kws)
        return_axes.append(cax)

    # make size bar
    if sizebar and (size is not None):
        # TODO plot dot size bar
        pass

    if zoomxy is not None:
        zoom_ax(ax, zoomxy)

    if return_fig:
        return tuple(return_axes), _data
    return None