import seaborn as sns
from matplotlib.lines import Line2D
import anndata
import pandas as pd
from .color import level_one_palette
from .contour import density_contour
from .text_anno_scatter import _text_anno_scatter
from .utilities import (
_make_tiny_axis_label,
_density_based_sample,
_extract_coords,
zoom_ax,
)
[docs]def categorical_scatter(
data,
ax,
# coords
coord_base="umap",
x=None,
y=None,
# color
hue=None,
palette="auto",
# text annotation
text_anno=None,
text_anno_kws=None,
text_anno_palette=None,
text_transform=None,
dodge_text=False,
dodge_kws=None,
# legend
show_legend=False,
legend_kws=None,
# size
s=5,
size=None,
sizes: dict = None,
size_norm=None,
# other
axis_format="tiny",
max_points=5000,
labelsize=4,
linewidth=0,
zoomxy=1.05,
outline=None,
outline_pad=3,
outline_kws=None,
scatter_kws=None,
return_fig=False,
rasterized='auto',
):
"""
Plot scatter plot with these options:
- Color by a categorical variable, and generate legend of the variable if needed
- Add text annotation using a categorical variable
- Circle categories with outlines
Parameters
----------
data
Dataframe that contains coordinates and categorical variables
ax
this function do not generate ax, must provide an ax
coord_base
coords name, if provided, will automatically search for x and y
x
x coord name
y
y coord name
hue
categorical col name or series for color hue
palette
palette of the hue, str or dict
text_anno
categorical col name or series for text annotation
text_anno_kws
text_anno_palette
text_transform
dodge_text
dodge_kws
show_legend
legend_kws
s
size
sizes
size_norm
axis_format
max_points
labelsize
linewidth
zoomxy
outline
outline_pad
outline_kws
scatter_kws
kws dict pass to sns.scatterplot
Returns
-------
"""
if isinstance(data, anndata.AnnData):
adata = data
data = adata.obs
x = f"{coord_base}_0"
y = f"{coord_base}_1"
_data = pd.DataFrame(
{
"x": adata.obsm[f"X_{coord_base}"][:, 0],
"y": adata.obsm[f"X_{coord_base}"][:, 1],
},
index=adata.obs_names,
)
else:
# add coords
_data, x, y = _extract_coords(data, coord_base, x, y)
# _data has 2 cols: "x" and "y"
# down sample plot data if needed.
if max_points is not None:
if _data.shape[0] > max_points:
_data = _density_based_sample(
_data, seed=1, size=max_points, coords=["x", "y"]
)
# determine rasterized
if rasterized == 'auto':
if _data.shape[0] > 200:
rasterized = True
else:
rasterized = False
# default scatter options
_scatter_kws = {"linewidth": 0,
"s": s,
"legend": None,
"palette": palette,
'rasterized': rasterized}
if scatter_kws is not None:
_scatter_kws.update(scatter_kws)
# deal with color
palette_dict = None
if hue is not None:
if isinstance(hue, str):
_data["hue"] = data[hue].copy()
else:
_data["hue"] = hue.copy()
hue = "hue"
_data["hue"] = _data["hue"].astype("category").cat.remove_unused_categories()
# deal with color palette
palette = _scatter_kws["palette"]
if isinstance(palette, str) or isinstance(palette, list):
palette_dict = level_one_palette(_data["hue"], order=None, palette=palette)
elif isinstance(palette, dict):
palette_dict = palette
else:
raise TypeError(
f"Palette can only be str, list or dict, " f"got {type(palette)}"
)
_scatter_kws["palette"] = palette_dict
# deal with size
if size is not None:
# discard s from _scatter_kws and use size in sns.scatterplot
_scatter_kws.pop("s")
sns.scatterplot(
x="x",
y="y",
data=_data,
ax=ax,
hue=hue,
size=size,
sizes=sizes,
size_norm=size_norm,
**_scatter_kws,
)
# deal with text annotation
if text_anno is not None:
if isinstance(text_anno, str):
_data["text_anno"] = data[text_anno].copy()
else:
_data["text_anno"] = text_anno.copy()
_text_anno_scatter(
data=_data[["x", "y", "text_anno"]],
ax=ax,
x="x",
y="y",
dodge_text=dodge_text,
dodge_kws=dodge_kws,
palette=text_anno_palette,
text_transform=text_transform,
anno_col="text_anno",
text_anno_kws=text_anno_kws,
labelsize=labelsize,
)
# deal with outline
if outline:
if isinstance(outline, str):
_data["outline"] = data[outline].copy()
else:
_data["outline"] = outline.copy()
_outline_kws = {
"linewidth": linewidth,
"palette": None,
"c": "lightgray",
"single_contour_pad": outline_pad,
}
if outline_kws is not None:
_outline_kws.update(outline_kws)
density_contour(
ax=ax, data=_data, x="x", y="y", groupby="outline", **_outline_kws
)
# clean axis
if axis_format == "tiny":
_make_tiny_axis_label(ax, x, y, arrow_kws=None, fontsize=labelsize)
elif (axis_format == "empty") or (axis_format is None):
sns.despine(ax=ax, left=True, bottom=True)
ax.set(xticks=[], yticks=[], xlabel=None, ylabel=None)
else:
pass
# deal with legend
if show_legend and (hue is not None):
n_hue = len(palette_dict)
_legend_kws = dict(
ncol=(1 if n_hue <= 20 else 2 if n_hue <= 40 else 3),
fontsize=labelsize,
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
if legend_kws is not None:
_legend_kws.update(legend_kws)
handles = []
labels = []
exist_hues = _data["hue"].unique()
for hue_name, color in palette_dict.items():
if hue_name not in exist_hues:
# skip hue_name that do not appear in the plot
continue
handle = Line2D(
[0],
[0],
marker="o",
color="w",
markerfacecolor=color,
markersize=_legend_kws["fontsize"],
)
handles.append(handle)
labels.append(hue_name)
_legend_kws["handles"] = handles
_legend_kws["labels"] = labels
ax.legend(**_legend_kws)
if zoomxy is not None:
zoom_ax(ax, zoomxy)
if return_fig:
return ax, _data
else:
return