import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import ticker
def plot_on_plate(
    data,
    value_col,
    groupby,
    ncols=4,
    plate_base=384,
    figsize=(5, 3.5),
    row_base="Row384",
    col_base="Col384",
    vmin=0,
    vmax=1,
    heatmap_kws=None,
    aggregation_func=None,
):
    """
    Plot a metadata column as plate-view heatmaps (384- or 96-well layout).

    One heatmap is drawn per group produced by ``data.groupby(groupby)``.

    Parameters
    ----------
    data
        dataframe containing ``value_col``, ``row_base``, ``col_base`` and the
        ``groupby`` column(s)
    value_col
        column whose values are drawn on the plate view
    groupby
        groupby column(s), typically the plate id column(s), so that each
        plate is plotted separately
    ncols
        number of axes columns; the number of axes rows is derived from the
        number of plates
    plate_base
        {384, 96} size of the plate view (16x24 or 8x12 wells)
    figsize
        size of a single panel; the figure is ``ncols`` x ``nrows`` panels
    row_base
        column holding the plate row of each record
    col_base
        column holding the plate column of each record
    vmin
        cmap vmin
    vmax
        cmap vmax
    heatmap_kws
        extra keyword arguments passed to ``sns.heatmap``
    aggregation_func
        applied to reduce records sharing the same (row, column) position
        within one plate; required when such duplicates exist

    Returns
    -------
    fig, axes
        the matplotlib figure and the axes object(s) returned by
        ``plt.subplots``

    Raises
    ------
    ValueError
        if ``plate_base`` is not 384 or 96, if a plate has duplicated well
        positions while ``aggregation_func`` is None, or if ``data`` yields
        no group to plot
    """
    import copy

    if plate_base == 384:
        plate_nrows, plate_ncols = 16, 24
    elif plate_base == 96:
        plate_nrows, plate_ncols = 8, 12
    else:
        raise ValueError(f"Plate base {plate_base} unknown")

    heatmap_data_list = []
    heatmap_names = []
    for plate, sub_df in data.groupby(groupby):
        # check if plate positions are duplicated within this plate
        duplicated = sub_df[[row_base, col_base]].duplicated().any()
        if duplicated:
            if aggregation_func is None:
                raise ValueError(
                    "Row after groupby is not unique, aggregation_func can not be None"
                )
            heatmap_data = (
                sub_df.groupby([row_base, col_base])[value_col]
                .apply(aggregation_func)
                .unstack()
            )
        else:
            heatmap_data = sub_df.set_index([row_base, col_base])[value_col].unstack()

        # Relabel rows/columns by position, then pad to the full plate shape so
        # every heatmap has identical dimensions.
        # NOTE(review): positional relabelling shifts wells when a whole
        # row/column is absent from the data - confirm this is intended.
        heatmap_data.index = range(heatmap_data.shape[0])
        heatmap_data.columns = range(heatmap_data.shape[1])
        heatmap_data = heatmap_data.reindex(
            index=list(range(plate_nrows)), columns=list(range(plate_ncols))
        )
        heatmap_data_list.append(heatmap_data)

        # Group keys may be tuples (multi-column groupby) and/or non-strings.
        if isinstance(plate, tuple):
            heatmap_names.append("\n".join(str(part) for part in plate))
        else:
            heatmap_names.append(str(plate))

    n_plates = len(heatmap_data_list)
    if n_plates == 0:
        raise ValueError("No plate found in data after groupby, nothing to plot")

    # Ceiling division: a partially filled last row still needs its own axes
    # row, otherwise trailing plates are dropped (or zero rows are requested).
    nrows = int(np.ceil(n_plates / ncols))
    fig, axes = plt.subplots(
        figsize=(figsize[0] * ncols, figsize[1] * nrows), ncols=ncols, nrows=nrows
    )
    fig.suptitle(
        f"{value_col} on {len(heatmap_data_list)}*{plate_base} plates \n Color Range [{vmin}, {vmax}]",
        fontsize=16,
    )

    # Work on a copy so the shared viridis colormap object is left untouched.
    cmap = copy.copy(plt.cm.viridis)
    cmap.set_under(color="#440154")
    cmap.set_over(color="#FDE725")
    cmap.set_bad(color="#FFFFFF")

    if heatmap_kws is None:
        heatmap_kws = {}
    for heatmap_data, heatmap_name, ax in zip(
        heatmap_data_list, heatmap_names, np.ravel(axes)
    ):
        sns.heatmap(heatmap_data, vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, **heatmap_kws)
        # NOTE(review): the explicit ylim presumably works around heatmap edge
        # clipping in some matplotlib releases - confirm it is still wanted.
        ax.set(title=heatmap_name, ylim=(-0.5, plate_nrows + 0.5))
    fig.tight_layout(rect=[0, 0.05, 1, 0.95])
    return fig, axes
def simple_violin(data, x, y, rotate_x=True):
    """
    Draw a seaborn violin plot on a new 5x3 inch figure.

    Parameters
    ----------
    data
        data passed to ``sns.violinplot``
    x
        variable shown on the x axis
    y
        variable shown on the y axis
    rotate_x
        if True, rotate the x tick labels by 90 degrees

    Returns
    -------
    fig, ax
        the created figure and its single axes
    """
    figure, axis = plt.subplots(figsize=(5, 3))
    sns.violinplot(data=data, x=x, y=y, ax=axis)
    if not rotate_x:
        return figure, axis
    axis.tick_params(axis="x", rotation=90)
    return figure, axis
def cutoff_vs_cell_remain(
    data, name="", xlim_quantile=(0.001, 0.999), ylim=None, bins=100
):
    """
    Plot a histogram of ``data`` together with the percent of values passing
    each candidate cutoff.

    Parameters
    ----------
    data
        1-D values to inspect; must support ``data > scalar``, boolean
        indexing and ``.size``
    name
        x axis label
    xlim_quantile
        (low, high) quantiles of ``data`` used as the x axis range
    ylim
        optional (low, high) limits for the histogram y axis
    bins
        number of histogram bins

    Returns
    -------
    fig, (ax1, ax2)
        the figure, the histogram axes and the twin axes carrying the
        "% of Data Pass Filter" curve
    """
    xlim = tuple(np.quantile(data, xlim_quantile))
    x = np.linspace(xlim[0], xlim[1], 500)

    # Percent of the *unclipped* data strictly greater than each cutoff in x.
    count_list = np.array([(data > i).sum() for i in x])
    original_total_data = data.size
    count_list = count_list / original_total_data * 100

    # Only the histogram is restricted to the quantile window.
    data = data[(data < xlim[1]) & (data > xlim[0])]

    fig, ax1 = plt.subplots()
    # Feature-detect histplot instead of wrapping the call in try/except, so an
    # AttributeError raised *inside* histplot is not mistaken for its absence.
    histplot = getattr(sns, "histplot", None)
    if histplot is not None:
        ax1 = histplot(data, bins=bins, kde=False, ax=ax1)
    else:
        # old seaborn version
        ax1 = sns.distplot(a=data, bins=bins, kde=False, ax=ax1)
    ax1.set_xlim(xlim)
    ax1.set_xlabel(name)
    if ylim is not None:
        ax1.set_ylim(*ylim)

    ax2 = ax1.twinx()
    ax2.plot(x, count_list, linewidth=2, linestyle="--", c="r")
    ax2.set_ylabel("% of Data Pass Filter", color="r")
    ax2.grid()
    return fig, (ax1, ax2)
def success_vs_fail(data, filter_col, filter_cutoff, x, y, ax):
    """
    Draw violin plots of ``y`` by ``x``, split by whether each record passes
    a cutoff.

    A copy of ``data`` receives a ``"filter"`` column holding ``"Success"``
    where ``data[filter_col] > filter_cutoff`` and ``"Fail"`` otherwise; that
    column is used as the violin plot hue. The input dataframe itself is not
    modified.

    Parameters
    ----------
    data
        dataframe containing ``filter_col``, ``x`` and ``y``
    filter_col
        column compared against ``filter_cutoff``
    filter_cutoff
        records strictly above this value are labelled "Success"
    x
        variable shown on the x axis
    y
        variable shown on the y axis
    ax
        axes to draw on

    Returns
    -------
    ax
        the axes that was passed in
    """

    def _verdict(passed):
        if passed:
            return "Success"
        return "Fail"

    labelled = data.copy()
    above_cutoff = data[filter_col] > filter_cutoff
    labelled["filter"] = above_cutoff.apply(_verdict)

    sns.violinplot(data=labelled, x=x, y=y, hue="filter", ax=ax)
    ax.tick_params(axis="x", rotation=90)
    return ax
def plot_dispersion(
    data,
    hue="gene_subset",
    zlab="dispersion",
    data_quantile=(0.01, 0.99),
    save_animate_path=None,
    fig_kws=None,
):
    """
    3D scatter of the ``"mean"``, ``"cov"`` and ``zlab`` columns of ``data``.

    Records are kept only when all three values lie strictly inside the
    ``data_quantile`` range of their own column. Points are colored
    "steelblue" where ``hue`` is True and "lightgray" where it is False.

    Parameters
    ----------
    data
        dataframe containing the ``"mean"``, ``"cov"``, ``zlab`` and ``hue``
        columns
    hue
        boolean column selecting the point color
    zlab
        column plotted on the z axis, also used as the z axis label
    data_quantile
        (low, high) quantiles used to trim each of the three columns
    save_animate_path
        if None, draw three static panels with different view angles;
        otherwise draw one panel, animate its azimuth over 100 frames and
        save the animation to this path with the "imagemagick" writer
    fig_kws
        overrides for the default figure keywords ``figsize=(12, 4)`` and
        ``dpi=160``

    Returns
    -------
    fig, axes
        the figure and either the list of three axes (static mode) or the
        single axes (animation mode)
    """
    # Imported but otherwise unused - presumably needed so that the "3d"
    # projection is available on older matplotlib; keep the import.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    mean_formatter = ticker.FuncFormatter(lambda value, pos: f"{value:.1f}")

    figure_options = {"figsize": (12, 4), "dpi": 160}
    if fig_kws is not None:
        figure_options.update(fig_kws)

    def _inside_quantiles(values):
        # boolean mask of values strictly between their own quantile bounds
        bounds = tuple(np.quantile(values, data_quantile))
        return (values < bounds[1]) & (values > bounds[0])

    # trim the dataframe directly so all plotted columns stay aligned
    keep = (
        _inside_quantiles(data["mean"])
        & _inside_quantiles(data["cov"])
        & _inside_quantiles(data[zlab])
    )
    trimmed = data[keep]

    palette = {True: "steelblue", False: "lightgray"}
    point_colors = trimmed[hue].map(palette).tolist()

    fig = plt.figure(**figure_options)

    if save_animate_path is not None:
        from matplotlib import animation

        axes = fig.add_subplot(111, projection="3d")
        axes.scatter(
            trimmed["mean"],
            trimmed["cov"],
            trimmed[zlab],
            c=point_colors,
            s=0.2,
            alpha=0.6,
        )

        def update(i):
            # frame index doubles as the azimuth angle
            axes.view_init(10, i)

        ani = animation.FuncAnimation(
            fig, func=update, frames=100, interval=10, blit=False
        )
        ani.save(save_animate_path, writer="imagemagick")
        return fig, axes

    axes = [
        fig.add_subplot(1, 3, position, projection="3d") for position in range(1, 4)
    ]
    view_angles = [(10, 10), (80, 45), (10, 80)]
    for panel, (ax, angles) in enumerate(zip(axes, view_angles)):
        ax.scatter(
            trimmed["mean"],
            trimmed["cov"],
            trimmed[zlab],
            c=point_colors,
            s=0.2,
            alpha=0.6,
        )
        ax.set_xlabel("Mean", labelpad=10)
        ax.set_ylabel("Cov", labelpad=10)

        # x axis: tick labels hidden on the first panel, one decimal elsewhere
        if panel == 0:
            ax.set_xticklabels([])
        else:
            ax.xaxis.set_major_formatter(mean_formatter)

        # y axis: scientific notation on the first two panels, hidden on the last
        if panel == 2:
            ax.set_yticklabels([])
        else:
            ax.ticklabel_format(style="sci", scilimits=(0, 0), axis="y")

        # z axis: tick labels hidden on the middle panel, labelled elsewhere
        if panel == 1:
            ax.set_zticklabels([])
        else:
            ax.set_zlabel(zlab)

        ax.view_init(*angles)
    return fig, axes
def plot_hvf_selection():
    """
    Placeholder with no implementation yet; takes no arguments and always
    returns None.

    NOTE(review): presumably meant to visualise an "hvf" selection step -
    confirm the intended behavior before relying on it.
    """
    return None