Source code for ALLCools.pseudo_cell.pseudo_cell_knn

import numpy as np
import warnings
from .pseudo_cell_kmeans import _merge_pseudo_cell


class ContractedExamplerSampler:
    def __init__(self, data, n_components=30, normalize=False):
        self.data = data
        if self.data.shape[1] > n_components:
            from sklearn.decomposition import PCA
            from sklearn.preprocessing import StandardScaler

            # standardize features, then reduce dimensionality with PCA
            data_std = StandardScaler().fit_transform(self.data)
            self.pca = PCA(n_components).fit_transform(data_std)
        else:
            self.pca = np.array(data)
        if normalize:
            # from sklearn.preprocessing import MaxAbsScaler
            # self.pca = MaxAbsScaler().fit_transform(self.pca)
            raise NotImplementedError
        from pynndescent import NNDescent

        # approximate nearest-neighbor index on the (reduced) coordinates
        self.ann_index = NNDescent(self.pca)
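    # A construction sketch in comment form (hypothetical variable names,
    # assuming a (n_cells, n_features) array of embedding coordinates):
    #
    #     coords = np.random.rand(5000, 50)
    #     sampler = ContractedExamplerSampler(coords, n_components=30)
    #     # sampler.pca: the 30-dimensional PCA coordinates
    #     # sampler.ann_index: the pynndescent index built on them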
    def _sample_pulp_dist(self, n_kernels, pulp_size):
        # draw random "kernel" cells and record the distances to each
        # kernel's pulp_size nearest neighbors
        kernels = np.random.choice(len(self.pca), n_kernels)
        pulps, dists = self.ann_index.query(self.pca[kernels], pulp_size)
        return dists
    def _select_dist_thresh(
        self, pulp_size, n_tests=100, pulp_thicken_ratio=1.2, robust_quantile=0.9
    ):
        # estimate a robust distance threshold: sample neighbor distances for
        # n_tests random kernels with a slightly thickened pulp, then take the
        # robust_quantile of all observed distances
        dists = self._sample_pulp_dist(n_tests, int(pulp_thicken_ratio * pulp_size))
        return np.quantile(dists, robust_quantile)
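    # Threshold heuristic in comment form: with the defaults, 100 random
    # kernels are queried for a pulp 1.2x larger than requested, and the 0.9
    # quantile of all observed neighbor distances becomes the threshold.
    # Hedged example (hypothetical sampler object):
    #
    #     thresh = sampler._select_dist_thresh(pulp_size=100)
    #     # roughly the radius within which ~100 neighbors are expected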
    def _sample_fruit(
        self,
        n_kernels,
        pulp_size,
        max_iters,
        dist_thresh=None,
        ovlp_tol=0.2,
        min_pulp_size=None,
        k=100,
    ):
        import scipy.spatial.distance as ssd

        _dist_thresh = (
            dist_thresh
            if dist_thresh is not None
            else self._select_dist_thresh(pulp_size)
        )
        kernels = []
        pulps = []
        unused = set(range(len(self.pca)))
        while max_iters > 0 and len(kernels) < n_kernels:
            max_iters -= 1
            if len(kernels) > 0:
                # propose k random unused cells, sorted by their distance to
                # the closest accepted kernel; keep only candidates farther
                # than the distance threshold from every accepted kernel
                kernel_cands = np.random.choice(list(unused), k)
                dists = ssd.cdist(self.pca[kernel_cands], self.pca[kernels])
                dists = dists.min(1)
                dists, kernel_cands = zip(*sorted(zip(dists, kernel_cands)))
                kernel_cands = np.array(kernel_cands)
                dists = np.array(dists)
                kernel_cands = kernel_cands[dists > _dist_thresh]
            else:
                kernel_cands = [np.random.choice(list(unused))]
            for kernel in kernel_cands:
                pulp, dists = self.ann_index.query([self.pca[kernel]], pulp_size)
                pulp = pulp.flatten()
                if (
                    (dist_thresh is None)
                    or ((min_pulp_size is None) and (dists < dist_thresh).all())
                    or (
                        (min_pulp_size is not None)
                        and (dists < dist_thresh).sum() >= min_pulp_size
                    )
                ):
                    # accept the kernel only if its pulp overlaps previously
                    # assigned cells by at most ovlp_tol
                    if len(set(pulp) - set(unused)) / len(pulp) <= ovlp_tol:
                        kernels.append(kernel)
                        pulps.append(pulp)
                        unused -= set(pulp)
                        break
        return kernels, pulps
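    # The loop above is a greedy rejection sampler built on a fruit metaphor:
    # each "kernel" is an exampler cell and its "pulp" is the kernel's
    # pulp_size nearest neighbors. A candidate kernel is accepted only if it
    # lies farther than _dist_thresh from every accepted kernel and its pulp
    # reuses at most an ovlp_tol fraction of already-assigned cells.
    # Hedged direct call (hypothetical values):
    #
    #     kernels, pulps = sampler._sample_fruit(
    #         n_kernels=20, pulp_size=100, max_iters=500, ovlp_tol=0.2
    #     )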
    def sample_contracted_examplers(
        self,
        n_examplers,
        n_neighbors,
        min_n_neighbors=None,
        ovlp_tol=0,
        dist_thresh=None,
        max_iters=100,
    ):
        if dist_thresh is None:
            dist_thresh = self._select_dist_thresh(n_neighbors)
        examplers, neighbors = self._sample_fruit(
            n_kernels=n_examplers,
            pulp_size=n_neighbors,
            max_iters=max_iters,
            dist_thresh=dist_thresh,
            ovlp_tol=ovlp_tol,
            min_pulp_size=min_n_neighbors,
            k=100,
        )
        return examplers, neighbors
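# Usage sketch for the sampler class in comment form (hypothetical data, not
# part of the module's public API):
#
#     coords = np.random.rand(5000, 50)
#     sampler = ContractedExamplerSampler(coords, n_components=30)
#     examplers, neighbors = sampler.sample_contracted_examplers(
#         n_examplers=20, n_neighbors=100
#     )
#     # examplers: row indices of the chosen pseudo-cell centers
#     # neighbors: one index array per exampler, the member cells of its group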
def sample_pseudo_cells(
    cell_meta,
    cluster_col,
    coords,
    target_pseudo_size,
    min_pseudo_size=None,
    ignore_small_cluster=False,
    n_components=30,
    pseudo_ovlp=0,
):
    _cell_meta = cell_meta[[cluster_col]].copy()
    index_name = _cell_meta.index.name
    _cell_meta = _cell_meta.reset_index()

    small_cluster_flags = []
    for c, cmeta in _cell_meta.groupby(cluster_col, as_index=False):
        n_pseudos = len(cmeta) // target_pseudo_size
        if n_pseudos == 0:
            if ignore_small_cluster:
                continue
            else:
                warnings.warn(
                    f'Size of cluster "{c}" is smaller than target pseudo-cell size.'
                )
                small_cluster_flags.append(True)
                # treat the whole small cluster as a single pseudo-cell
                pseudo_centers = [0]
                pseudo_groups = [list(range(cmeta.shape[0]))]
        else:
            small_cluster_flags.append(False)
            sampler = ContractedExamplerSampler(coords[cmeta.index], n_components)
            pseudo_centers, pseudo_groups = sampler.sample_contracted_examplers(
                n_pseudos, target_pseudo_size, min_pseudo_size, ovlp_tol=pseudo_ovlp
            )
        for i, (pcenter, pgroup) in enumerate(zip(pseudo_centers, pseudo_groups)):
            _cell_meta.loc[cmeta.iloc[pcenter].name, "pseudo_center"] = f"{c}::{i}"
            _cell_meta.loc[cmeta.iloc[pgroup].index, "pseudo_cell"] = f"{c}::{i}"
    _cell_meta = _cell_meta.set_index(index_name)

    # per-cluster summary: total cells, pseudo-cell count, covered cells,
    # and the fraction of cells assigned to any pseudo-cell
    stats = _cell_meta.copy()
    stats.index.name = "total_cells"
    stats = stats.reset_index().groupby(cluster_col, as_index=False).count()
    stats["cover_ratio"] = stats["pseudo_cell"] / stats["total_cells"]
    stats.columns = [
        cluster_col,
        "total_cells",
        "pseudo_cells",
        "covered_cells",
        "cover_ratio",
    ]
    return _cell_meta, stats
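# Usage sketch in comment form (hypothetical objects; assumes cell_meta is a
# DataFrame indexed by cell ID and coords rows align with cell_meta rows):
#
#     cell_group, stats = sample_pseudo_cells(
#         cell_meta=cell_meta,
#         cluster_col="leiden",
#         coords=pca_coords,
#         target_pseudo_size=100,
#     )
#     # cell_group["pseudo_cell"]: "<cluster>::<i>" label per covered cell
#     # stats: per-cluster totals, pseudo-cell counts, and cover_ratio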
# A wrapper around sample_pseudo_cells that takes an AnnData object as input.
def generate_pseudo_cells(
    adata,
    cluster_col="leiden",
    obsm="X_pca",
    target_pseudo_size=100,
    min_pseudo_size=None,
    ignore_small_cluster=False,
    n_components=None,
    aggregate_func="downsample",
    pseudo_ovlp=0,
):
    if n_components is None:
        n_components = adata.obsm[obsm].shape[1]
    if min_pseudo_size is None:
        min_pseudo_size = 1

    # determine the pseudo-cell group of each cell
    cell_group, stats = sample_pseudo_cells(
        cell_meta=adata.obs,
        cluster_col=cluster_col,
        coords=adata.obsm[obsm],
        target_pseudo_size=target_pseudo_size,
        min_pseudo_size=min_pseudo_size,
        ignore_small_cluster=ignore_small_cluster,
        n_components=n_components,
        pseudo_ovlp=pseudo_ovlp,
    )
    adata.obs["cell_group"] = cell_group["pseudo_cell"]

    # aggregate cells within each group into pseudo-cells
    pseudo_cell_adata = _merge_pseudo_cell(adata=adata, aggregate_func=aggregate_func)
    # recover the cluster label from the "<cluster>::<i>" pseudo-cell names
    pseudo_cell_adata.obs[cluster_col] = pseudo_cell_adata.obs_names.str.split(
        "::"
    ).str[0]
    return pseudo_cell_adata
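# End-to-end sketch in comment form (hypothetical AnnData; assumes
# adata.obs["leiden"] and adata.obsm["X_pca"] exist, e.g. from a standard
# scanpy workflow):
#
#     pseudo_adata = generate_pseudo_cells(
#         adata, cluster_col="leiden", obsm="X_pca", target_pseudo_size=100
#     )
#     # one observation per pseudo-cell, named "<cluster>::<i>", with the
#     # cluster label restored in pseudo_adata.obs["leiden"]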