Source code for ALLCools.clustering.ClusterMerging

import numpy as np
import scipy
import scipy.cluster.hierarchy as sch
from .dmg import PairwiseDMG
class ClusterMerge:
    """Iteratively merge sibling clusters that a criterion cannot separate.

    Each round builds a hierarchical tree over per-cluster mean profiles,
    collects every internal node whose two children are both leaves, and asks
    ``merge_criterion`` whether that pair of clusters is separable. Pairs that
    are not separable get a joint label. Rounds repeat until a stop condition
    is met (see :meth:`fit_predict`).

    Parameters
    ----------
    merge_criterion
        Object exposing ``predict(pair_labels, pair_cells, pair_cell_types,
        pair_mcds)`` that returns a sequence whose first two items are
        ``separable`` (truthy when the pair must stay apart) and ``evidence``.
    stop_criterion
        Optional callable ``(data_for_tree, cell_to_type, gene_mcds)``; a
        truthy return value ends the merging after the current round.
    stop_clusters
        When positive, merging stops once the number of distinct labels is at
        most this value.
    n_cells
        Number of cells drawn with replacement from every cluster before the
        per-cluster mean is computed.
    metric
        Distance metric handed to ``scipy.spatial.distance.pdist``.
    method
        Linkage method handed to ``scipy.cluster.hierarchy.linkage``.
    label_concat_str
        Separator placed between the two labels of a merged pair.
    verbose
        When true (the default), print progress messages.
    """

    def __init__(self, merge_criterion, stop_criterion=None, stop_clusters=-1,
                 n_cells=200, metric='euclidean', method='average',
                 label_concat_str='::', verbose=True):
        # Inputs of the current fit; populated by fit_predict().
        self.data_for_tree = None
        self.cell_to_type = None
        self.gene_mcds = None

        self.n_cells = n_cells
        self.metric = metric
        self.method = method
        self.label_concat_str = label_concat_str

        self.merge_criterion = merge_criterion
        self.stop_criterion = stop_criterion
        self.stop_clusters = stop_clusters
        self.verbose = verbose

        # State of the current round; populated by _construct_tree().
        self.curr_tree = None
        self.curr_mean = None
        # Maps a sorted (label, label) tuple to the evidence of that merge.
        self.merge_evidences = {}

    def _log(self, *args):
        """Print ``args`` when verbose output is enabled."""
        if self.verbose:
            print(*args)

    def _construct_tree(self):
        """Rebuild ``curr_mean`` and ``curr_tree`` from the current labels.

        Requires at least two distinct labels; ``linkage`` raises on an empty
        distance matrix otherwise.
        """
        # Imported explicitly: a bare ``import scipy`` does not guarantee that
        # the ``scipy.spatial`` subpackage has been loaded.
        from scipy.spatial.distance import pdist

        labels = self.cell_to_type
        # Same-size sample per cluster (with replacement) so that every
        # cluster mean is computed from n_cells rows. Grouping the Series by
        # itself also works when the Series has no name.
        sampled = labels.groupby(labels).sample(n=self.n_cells, replace=True)
        cells = sampled.sort_index().index

        # Both operands are indexed by ``cells`` in the same order, so the
        # grouper lines up row by row even though ``cells`` has duplicates.
        self.curr_mean = (
            self.data_for_tree.loc[cells].groupby(labels.loc[cells]).mean()
        )

        distances = pdist(self.curr_mean, metric=self.metric)
        linkage = sch.linkage(distances, metric=self.metric, method=self.method)
        self.curr_tree = sch.to_tree(linkage)

    @staticmethod
    def _traverse(node, call_back):
        """Call ``call_back`` on every node whose two children are leaves.

        Nodes are visited in pre-order, left subtree first. An explicit stack
        is used so deep trees cannot hit the recursion limit.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current.is_leaf():
                continue
            left_is_leaf = current.left.is_leaf()
            right_is_leaf = current.right.is_leaf()
            if left_is_leaf and right_is_leaf:
                call_back(current)
                continue
            # Push right first so the left subtree is processed first.
            if not right_is_leaf:
                stack.append(current.right)
            if not left_is_leaf:
                stack.append(current.left)

    def _merge_pair(self, pair, concat_str=None):
        """Test one pair of leaf clusters and merge them if not separable.

        Parameters
        ----------
        pair
            ``(left_id, right_id)`` leaf ids, i.e. row positions in
            ``curr_mean``.
        concat_str
            Separator for the merged label; defaults to
            ``self.label_concat_str``.

        Returns
        -------
        bool
            True when the two clusters were merged.
        """
        if concat_str is None:
            concat_str = self.label_concat_str

        left_id, right_id = pair
        left_lbl, right_lbl = self.curr_mean.iloc[[left_id, right_id]].index
        self._log('checking', left_lbl, right_lbl)

        in_pair = self.cell_to_type.isin([left_lbl, right_lbl])
        pair_cell_types = self.cell_to_type[in_pair]
        pair_cells = pair_cell_types.index
        pair_mcds = self.gene_mcds.sel(cell=pair_cells)
        pair_mcds.load()

        separable, evidence, *_ = self.merge_criterion.predict(
            (left_lbl, right_lbl), pair_cells, pair_cell_types, pair_mcds)
        mergeable = not separable
        if mergeable:
            # Written through to the Series passed to fit_predict().
            self.cell_to_type.loc[pair_cells] = f'{left_lbl}{concat_str}{right_lbl}'
            # Sorted key: the evidence is found regardless of pair order.
            self.merge_evidences[tuple(sorted([left_lbl, right_lbl]))] = evidence
            self._log(left_lbl, right_lbl, 'merged')
        return mergeable

    def fit_predict(self, data_for_tree, cell_to_type, gene_mcds):
        """Merge clusters round by round until a stop condition is met.

        Merging stops when fewer than two clusters remain, when
        ``stop_clusters`` is reached, when a round merges nothing, or when
        ``stop_criterion`` returns a truthy value.

        Parameters
        ----------
        data_for_tree
            Table indexed by cell; rows are averaged per cluster to build the
            tree.
        cell_to_type
            Series mapping cell to cluster label. It is updated in place.
        gene_mcds
            Object supporting ``.sel(cell=...)`` whose result supports
            ``.load()`` (presumably an xarray Dataset; confirm with callers).

        Returns
        -------
        tuple
            ``(cell_to_type, merge_evidences)``.
        """
        self.data_for_tree = data_for_tree
        self.cell_to_type = cell_to_type
        self.gene_mcds = gene_mcds

        round_idx = 0
        while True:
            # A tree needs at least two clusters; without this guard the
            # round after a final merge fails inside linkage().
            if self.cell_to_type.nunique() < 2:
                break

            round_idx += 1
            self._construct_tree()
            pairs = []
            self._traverse(
                self.curr_tree,
                lambda node: pairs.append((node.left.id, node.right.id)))
            n_merged = sum([self._merge_pair(pair) for pair in pairs])
            self._log('round', round_idx, 'merged', n_merged)

            reached_target = (
                self.stop_clusters > 0
                and len(self.cell_to_type.unique()) <= self.stop_clusters
            )
            if reached_target or n_merged == 0:
                self._log()
                break
            if self.stop_criterion is not None and self.stop_criterion(
                    self.data_for_tree, self.cell_to_type, self.gene_mcds):
                self._log()
                break
        return self.cell_to_type, self.merge_evidences
class PairwiseDMGCriterion:
    """Decide whether two clusters are separable from pairwise DMG counts.

    A modality counts as separating the pair when the ``PairwiseDMG`` result
    table has at least ``top_n_markers`` rows.

    Parameters
    ----------
    max_cell_per_group, adj_p_cutoff, delta_rate_cutoff, auroc_cutoff, n_jobs
        Forwarded unchanged to ``PairwiseDMG``.
    top_n_markers
        Forwarded to ``PairwiseDMG`` as ``top_n`` and also used as the minimum
        number of table rows for a modality to count as separating.
    use_modality
        ``'either'``: one separating modality is enough. ``'both'``: every
        modality must separate. Any other value raises ``KeyError``.
    random_state, verbose
        Forwarded unchanged to ``PairwiseDMG``.
    """

    def __init__(self,
                 max_cell_per_group=100,
                 top_n_markers=5,
                 adj_p_cutoff=0.001,
                 delta_rate_cutoff=0.3,
                 auroc_cutoff=0.85,
                 use_modality='either',
                 random_state=0,
                 n_jobs=10,
                 verbose=False,
                 ):
        # Combines the two per-modality flags into the final decision.
        self.agg = {'either': np.logical_or, 'both': np.logical_and}[use_modality]
        self.pwdmg = PairwiseDMG(max_cell_per_group=max_cell_per_group,
                                 top_n=top_n_markers,
                                 adj_p_cutoff=adj_p_cutoff,
                                 delta_rate_cutoff=delta_rate_cutoff,
                                 auroc_cutoff=auroc_cutoff,
                                 # Honor the caller's settings rather than
                                 # hard-coding 0 / False.
                                 random_state=random_state,
                                 n_jobs=n_jobs,
                                 verbose=verbose)
        self.top_n_markers = top_n_markers
        self.use_modality = use_modality

    def predict(self, pair_labels, pair_cells, pair_cell_types, pair_mcds,
                da_name='gene_da_frac'):
        """Test whether the two clusters in ``pair_cell_types`` are separable.

        Parameters
        ----------
        pair_labels, pair_cells
            Accepted for interface compatibility; not used here.
        pair_cell_types
            Group assignment passed to ``PairwiseDMG.fit_predict`` as
            ``groups``.
        pair_mcds
            Mapping-like object; ``pair_mcds[da_name]`` must support
            ``.sel(mc_type=...)``.
        da_name
            Key of the data array inside ``pair_mcds``.

        Returns
        -------
        tuple
            ``(separable, evidence)``. ``evidence`` maps each evaluated
            mc_type to the ``dmg_table`` produced for it.
        """
        # Exactly two modalities: self.agg is a binary numpy ufunc.
        mc_types = ['CHN', 'CGN']
        separable = {mc_type: False for mc_type in mc_types}
        evidence = {}
        for mc_type in mc_types:
            self.pwdmg.fit_predict(x=pair_mcds[da_name].sel(mc_type=mc_type),
                                   groups=pair_cell_types)
            dmg_table = self.pwdmg.dmg_table
            evidence[mc_type] = dmg_table
            if len(dmg_table) >= self.top_n_markers:
                separable[mc_type] = True
                # NOTE(review): nesting reconstructed from flattened source.
                # With 'either', one separating modality settles the answer,
                # so the remaining modality is skipped.
                if self.use_modality == 'either':
                    break
        return self.agg(*separable.values()), evidence