import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib import ticker
def plot_on_plate(
    data,
    value_col,
    groupby,
    ncols=4,
    plate_base=384,
    figsize=(5, 3.5),
    row_base="Row384",
    col_base="Col384",
    vmin=0,
    vmax=1,
    heatmap_kws=None,
    aggregation_func=None,
):
    """
    Plot metadata into 384 or 96 plate view (heatmap), one axes per group.

    Parameters
    ----------
    data
        dataframe containing all kinds of metadata
    value_col
        value to be plotted on plate view
    groupby
        groupby column, typically groupby plate id column(s) to plot each plate separately
    ncols
        number of columns of axes, nrows will be calculated accordingly
    plate_base
        {384, 96} size of the plate view
    figsize
        size of a single plate panel; the figure size is this value scaled
        by the number of axes columns and rows
    row_base
        column of ``data`` holding the plate row of each record
    col_base
        column of ``data`` holding the plate column of each record
    vmin
        cmap vmin
    vmax
        cmap vmax
    heatmap_kws
        kws pass to sns.heatmap
    aggregation_func
        apply to reduce rows after groupby if the row is not unique

    Returns
    -------
    fig, axes
        the figure and the axes object(s) returned by ``plt.subplots``

    Raises
    ------
    ValueError
        if ``plate_base`` is not 384 or 96, or if a group has duplicated
        (row, column) positions and ``aggregation_func`` is None
    """
    import copy

    if plate_base == 384:
        plate_nrows, plate_ncols = 16, 24
    elif plate_base == 96:
        plate_nrows, plate_ncols = 8, 12
    else:
        raise ValueError(f"Plate base {plate_base} unknown")

    heatmap_data_list = []
    heatmap_names = []
    for plate, sub_df in data.groupby(groupby):
        # check if plate base are duplicated
        duplicated = sub_df[[row_base, col_base]].duplicated().sum() != 0
        if duplicated:
            if aggregation_func is None:
                raise ValueError(
                    "Row after groupby is not unique, aggregation_func can not be None"
                )
            heatmap_data = (
                sub_df.groupby([row_base, col_base])[value_col]
                .apply(aggregation_func)
                .unstack()
            )
        else:
            heatmap_data = sub_df.set_index([row_base, col_base])[value_col].unstack()

        # reindex to make sure heatmap data in the shape of plate
        # NOTE(review): rows/columns are relabeled by position, so a plate
        # row or column that is entirely absent from sub_df shifts the
        # following ones - confirm this is acceptable for the input data.
        heatmap_data.index = range(heatmap_data.shape[0])
        heatmap_data.columns = range(heatmap_data.shape[1])
        heatmap_data = heatmap_data.reindex(
            index=list(range(plate_nrows)), columns=list(range(plate_ncols))
        )
        heatmap_data_list.append(heatmap_data)

        # Group keys may be strings, other scalars (e.g. int plate ids) or
        # tuples of mixed types; always build a string title.
        if isinstance(plate, str):
            heatmap_names.append(plate)
        elif isinstance(plate, tuple):
            heatmap_names.append("\n".join(map(str, plate)))
        else:
            heatmap_names.append(str(plate))

    # Round up so every plate gets an axes (round() would drop the last
    # plates or even produce zero rows); keep at least one row.
    nrows = max(1, int(np.ceil(len(heatmap_data_list) / ncols)))
    fig, axes = plt.subplots(
        figsize=(figsize[0] * ncols, figsize[1] * nrows), ncols=ncols, nrows=nrows
    )
    fig.suptitle(
        f"{value_col} on {len(heatmap_data_list)}*{plate_base} plates \n Color Range [{vmin}, {vmax}]",
        fontsize=16,
    )

    # Work on a copy so the shared viridis colormap is not modified.
    cmap = copy.copy(plt.cm.viridis)
    cmap.set_under(color="#440154")
    cmap.set_over(color="#FDE725")
    cmap.set_bad(color="#FFFFFF")

    if heatmap_kws is None:
        heatmap_kws = {}
    for heatmap_data, heatmap_name, ax in zip(
        heatmap_data_list, heatmap_names, np.ravel(axes)
    ):
        sns.heatmap(heatmap_data, vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, **heatmap_kws)
        # NOTE(review): the explicit ylim presumably works around clipped
        # top/bottom heatmap rows in some matplotlib versions - confirm.
        ax.set(title=heatmap_name, ylim=(-0.5, plate_nrows + 0.5))
    fig.tight_layout(rect=[0, 0.05, 1, 0.95])
    return fig, axes
def simple_violin(data, x, y, rotate_x=True):
    """
    Draw a violin plot of ``y`` against ``x`` on a new 5x3 figure.

    Parameters
    ----------
    data
        data passed to sns.violinplot
    x
        name passed as the ``x`` argument of sns.violinplot
    y
        name passed as the ``y`` argument of sns.violinplot
    rotate_x
        if True, rotate the x tick labels by 90 degrees

    Returns
    -------
    fig, ax
        the created figure and axes
    """
    fig, ax = plt.subplots(figsize=(5, 3))
    sns.violinplot(x=x, y=y, data=data, ax=ax)
    if rotate_x:
        ax.tick_params(axis="x", rotation=90)
    return fig, ax
def cutoff_vs_cell_remain(
    data, name="", xlim_quantile=(0.001, 0.999), ylim=None, bins=100
):
    """
    Plot a histogram of ``data`` together with the percentage of data
    remaining above each possible cutoff value.

    Parameters
    ----------
    data
        1-D numeric values; must support ``.size``, element-wise comparison
        and boolean indexing (presumably a numpy array or pandas Series -
        verify against callers)
    name
        x axis label
    xlim_quantile
        (low, high) quantiles of ``data`` used as x axis limits; values
        outside this range are excluded from the histogram
    ylim
        optional (low, high) limits of the histogram y axis
    bins
        number of histogram bins

    Returns
    -------
    fig, (ax1, ax2)
        the figure, the histogram axes and the twin axes holding the
        "% of Data Pass Filter" curve
    """
    xlim = tuple(np.quantile(data, xlim_quantile))
    # candidate cutoffs spanning the displayed x range
    x = np.linspace(xlim[0], xlim[1], 500)

    # percentage of ALL data (before trimming) strictly above each cutoff
    count_list = np.array([(data > i).sum() for i in x])
    original_total_data = data.size
    count_list = count_list / original_total_data * 100

    # only the values strictly inside the quantile range go in the histogram
    data = data[(data < xlim[1]) & (data > xlim[0])]

    fig, ax1 = plt.subplots()
    try:
        ax1 = sns.histplot(data, bins=bins, kde=False, ax=ax1)
    except AttributeError:
        # old seaborn version
        ax1 = sns.distplot(a=data, bins=bins, kde=False, ax=ax1)
    ax1.set_xlim(xlim)
    ax1.set_xlabel(name)
    if ylim is not None:
        ax1.set_ylim(*ylim)

    # second y axis sharing the x axis for the pass-filter curve
    ax2 = ax1.twinx()
    ax2.plot(x, count_list, linewidth=2, linestyle="--", c="r")
    ax2.set_ylabel("% of Data Pass Filter", color="r")
    ax2.grid()
    return fig, (ax1, ax2)
def success_vs_fail(data, filter_col, filter_cutoff, x, y, ax):
    """
    Draw a violin plot of ``y`` against ``x``, split by whether each record
    passes a cutoff.

    Parameters
    ----------
    data
        dataframe holding ``filter_col`` and the plotted columns; it is
        copied, the caller's object is not modified
    filter_col
        column compared against ``filter_cutoff``
    filter_cutoff
        records with ``filter_col`` strictly greater than this value are
        labeled "Success", all others "Fail"
    x
        name passed as the ``x`` argument of sns.violinplot
    y
        name passed as the ``y`` argument of sns.violinplot
    ax
        axes to draw on

    Returns
    -------
    ax
        the axes passed in
    """
    use_data = data.copy()
    passed = data[filter_col] > filter_cutoff
    # the label is stored in a "filter" column (replacing any existing
    # column of that name in the copy) and used as the violin hue
    use_data["filter"] = passed.map({True: "Success", False: "Fail"})
    sns.violinplot(x=x, y=y, data=use_data, hue="filter", ax=ax)
    ax.tick_params(axis="x", rotation=90)
    return ax
def plot_dispersion(
    data,
    hue="gene_subset",
    zlab="dispersion",
    data_quantile=(0.01, 0.99),
    save_animate_path=None,
    fig_kws=None,
):
    """
    3D scatter plot of the "mean", "cov" and ``zlab`` columns of ``data``.

    Without ``save_animate_path`` three static panels with different view
    angles are drawn; with it, a single rotating panel is drawn and saved
    as an animation.

    Parameters
    ----------
    data
        dataframe with "mean" and "cov" columns plus the ``zlab`` and
        ``hue`` columns
    hue
        column mapped to point colors: True -> "steelblue",
        False -> "lightgray"
    zlab
        column plotted on the z axis, also used as the z axis label
    data_quantile
        (low, high) quantiles; only rows strictly inside this range on all
        three plotted columns are drawn
    save_animate_path
        if given, path the animation is saved to (using the "imagemagick"
        writer)
    fig_kws
        kws pass to plt.figure, overriding the defaults
        ``figsize=(12, 4), dpi=160``

    Returns
    -------
    fig, axes
        the figure and either a list of three 3D axes (static mode) or a
        single 3D axes (animation mode)
    """
    from mpl_toolkits.mplot3d import Axes3D

    if Axes3D.__doc__:
        # touch the Axes3D to prevent ide remove it...
        pass

    @ticker.FuncFormatter
    def mean_formatter(x, pos):
        return f"{x:.1f}"

    _fig_kws = {"figsize": (12, 4), "dpi": 160}
    if fig_kws is not None:
        _fig_kws.update(fig_kws)

    x = data["mean"]
    y = data["cov"]
    z = data[zlab]
    xlim = tuple(np.quantile(x, data_quantile))
    ylim = tuple(np.quantile(y, data_quantile))
    zlim = tuple(np.quantile(z, data_quantile))

    # directly apply lim on df
    _df = data[
        (x < xlim[1])
        & (x > xlim[0])
        & (y < ylim[1])
        & (y > ylim[0])
        & (z < zlim[1])
        & (z > zlim[0])
    ]

    color_dict = {True: "steelblue", False: "lightgray"}
    color = _df[hue].map(color_dict).tolist()

    fig = plt.figure(**_fig_kws)
    if save_animate_path is None:
        axes = [fig.add_subplot(1, 3, i, projection="3d") for i in range(1, 4)]
        view_inits = [(10, 10), (80, 45), (10, 80)]
        for i, (ax, view) in enumerate(zip(axes, view_inits)):
            ax.scatter(_df["mean"], _df["cov"], _df[zlab], c=color, s=0.2, alpha=0.6)
            ax.set_xlabel("Mean", labelpad=10)
            ax.set_ylabel("Cov", labelpad=10)

            # x axis: tick labels hidden on the first panel only
            if i == 0:
                ax.set_xticklabels([])
            else:
                ax.xaxis.set_major_formatter(mean_formatter)

            # y axis: tick labels hidden on the last panel only
            if i == 2:
                ax.set_yticklabels([])
            else:
                ax.ticklabel_format(style="sci", scilimits=(0, 0), axis="y")

            # z axis: tick labels hidden and no axis label on the middle panel
            if i == 1:
                ax.set_zticklabels([])
            else:
                ax.set_zlabel(zlab)

            ax.view_init(*view)
    else:
        axes = fig.add_subplot(111, projection="3d")
        axes.scatter(_df["mean"], _df["cov"], _df[zlab], c=color, s=0.2, alpha=0.6)

        from matplotlib import animation

        def update(i):
            # the frame number is passed as the second view_init argument
            axes.view_init(10, i)

        ani = animation.FuncAnimation(
            fig, func=update, frames=100, interval=10, blit=False
        )
        ani.save(save_animate_path, writer="imagemagick")
    return fig, axes
def plot_hvf_selection():
    """Placeholder, not implemented: takes no arguments and returns None."""
    return