Source code for scHopfield.tools.correlation

"""Correlation analysis between energies, genes, and cell types."""

import numpy as np
import pandas as pd
import itertools
from typing import Optional
from anndata import AnnData
import hoggorm as ho

from .._utils.io import get_matrix, to_numpy, get_genes_used


_ENERGY_TYPES = ('total', 'interaction', 'degradation', 'bias')


def _pearson_cross_corr(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Pearson correlation between every row of ``A`` and every row of ``B``.

    Returns the same values as
    ``np.nan_to_num(np.corrcoef(np.vstack((A, B)))[:len(A), len(A):])``.
    It never forms the ``len(B) x len(B)`` block, so the cost is linear
    rather than quadratic in the number of rows of ``B``. Undefined
    correlations (zero-variance rows, or a single column) come out as 0.

    Parameters
    ----------
    A : array-like, shape (m, n)
    B : array-like, shape (k, n)

    Returns
    -------
    np.ndarray, shape (m, k), dtype float64
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    A_c = A - A.mean(axis=1, keepdims=True)
    B_c = B - B.mean(axis=1, keepdims=True)
    num = A_c @ B_c.T
    denom = np.outer(np.linalg.norm(A_c, axis=1), np.linalg.norm(B_c, axis=1))
    with np.errstate(divide='ignore', invalid='ignore'):
        corr = num / denom
    # Rounding can push |r| slightly above 1; np.corrcoef clips the same way.
    return np.nan_to_num(np.clip(corr, -1.0, 1.0))


def _write_energy_correlations(adata, gene_names, correlations, suffix):
    """Store a (4, n_genes) correlation array as ``correlation_{etype}_{suffix}`` var columns.

    A column is initialised to 0.0 when it is first created. Only the rows in
    ``gene_names`` are then overwritten.
    """
    for etype, row in zip(_ENERGY_TYPES, correlations):
        col = f'correlation_{etype}_{suffix}'
        if col not in adata.var:
            adata.var[col] = 0.0
        adata.var.loc[gene_names, col] = row


def energy_gene_correlation(
    adata: AnnData,
    spliced_key: str = 'Ms',
    cluster_key: str = 'cell_type',
    copy: bool = False
) -> Optional[AnnData]:
    """
    Correlate energies with gene expression.

    Computes the Pearson correlation between each energy component and each
    gene's expression, separately for every cluster and for all cells.

    Adapted from Landscape.energy_genes_correlation.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with computed energies. Expects the
        ``adata.obs`` columns 'energy_total', 'energy_interaction',
        'energy_degradation' and 'energy_bias'.
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for expression data
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels. A cluster literally named
        'all' is skipped, because that suffix is reserved for the
        all-cells correlation.
    copy : bool, optional (default: False)
        If True, return a copy instead of modifying in-place

    Returns
    -------
    AnnData or None
        Returns adata if copy=True, otherwise None. Adds to adata.var, for
        each cluster and for 'all':
        - 'correlation_total_{cluster}'
        - 'correlation_interaction_{cluster}'
        - 'correlation_degradation_{cluster}'
        - 'correlation_bias_{cluster}'
        Genes outside the analysed set are 0 in newly created columns.
        Undefined correlations (e.g. constant expression) are stored as 0.
    """
    adata = adata.copy() if copy else adata
    genes = get_genes_used(adata)
    gene_names = adata.var.index[genes]

    # Energies are already computed with cluster-specific parameters and live
    # in shared obs columns; stack them as a (4, n_obs) float64 array.
    energies = np.vstack([
        adata.obs[f'energy_{etype}'].to_numpy(dtype=np.float64)
        for etype in _ENERGY_TYPES
    ])

    # Fetch the expression matrix once instead of once per cluster.
    counts = get_matrix(adata, spliced_key, genes=genes)
    labels = adata.obs[cluster_key]

    for cluster in labels.unique():
        if cluster == 'all':
            continue
        cells = (labels == cluster).values
        X = to_numpy(counts[cells].T)  # genes x cells of this cluster
        _write_energy_correlations(
            adata, gene_names, _pearson_cross_corr(energies[:, cells], X), cluster
        )

    # Same correlation over every cell.
    X_all = to_numpy(counts.T)
    _write_energy_correlations(
        adata, gene_names, _pearson_cross_corr(energies, X_all), 'all'
    )

    return adata if copy else None


def celltype_correlation(
    adata: AnnData,
    spliced_key: str = 'Ms',
    cluster_key: str = 'cell_type',
    modified: bool = True,
    all_genes: bool = False,
    copy: bool = False
) -> Optional[AnnData]:
    """
    Compute correlation between cell types based on gene expression.

    Uses the RV coefficient to measure similarity between cell type
    expression profiles.

    Adapted from Landscape.celltype_correlation.

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for expression data
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    modified : bool, optional (default: True)
        If True, use modified RV2 coefficient; if False, use RV coefficient
    all_genes : bool, optional (default: False)
        If True, use all genes; if False, use only genes from analysis
    copy : bool, optional (default: False)
        If True, return a copy instead of modifying in-place

    Returns
    -------
    AnnData or None
        Returns adata if copy=True, otherwise None. Adds to
        adata.uns['scHopfield'], creating that dict if it does not exist yet:
        - 'celltype_correlation': symmetric DataFrame of pairwise
          coefficients, with 1.0 on the diagonal
    """
    adata = adata.copy() if copy else adata

    labels = adata.obs[cluster_key]
    keys = labels.unique()
    corr_f = ho.mat_corr_coeff.RV2coeff if modified else ho.mat_corr_coeff.RVcoeff
    # The diagonal stays 1.0 (each cell type compared with itself).
    rv = pd.DataFrame(index=keys, columns=keys, data=1.0)

    genes = None if all_genes else get_genes_used(adata)
    counts = get_matrix(adata, spliced_key, genes=genes)

    # One boolean cell mask per cluster, built once instead of once per pair.
    masks = [(labels == k).values for k in keys]
    for i, j in itertools.combinations(range(len(keys)), 2):
        k1, k2 = keys[i], keys[j]
        expr_k1 = to_numpy(counts[masks[i]])
        expr_k2 = to_numpy(counts[masks[j]])
        # RV coefficients compare matrices that share their rows, so pass
        # genes x cells: the two clusters may differ in cell count.
        value = corr_f([expr_k1.T, expr_k2.T])[0, 1]
        rv.loc[k1, k2] = value
        rv.loc[k2, k1] = value

    adata.uns.setdefault('scHopfield', {})['celltype_correlation'] = rv
    return adata if copy else None


def future_celltype_correlation(
    adata: AnnData,
    spliced_key: str = 'Ms',
    cluster_key: str = 'cell_type',
    modified: bool = True,
    copy: bool = False
) -> Optional[AnnData]:
    """
    Compute correlation between cell types based on predicted future states.

    For each cluster k, the "future state" is ``W_k @ sigmoid(x_k).T``
    (genes x cells). Here ``x_k`` is the expression of the cluster's cells,
    and ``W_k`` is ``adata.varp[f'W_{k}']``. Pairs of clusters are then
    compared with the RV coefficient.

    Adapted from Landscape.future_celltype_correlation.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions. Requires the
        adata.var columns 'sigmoid_threshold' and 'sigmoid_exponent', and
        adata.varp['W_{cluster}'] for every cluster.
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for expression data
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    modified : bool, optional (default: True)
        If True, use modified RV2 coefficient; if False, use RV coefficient
    copy : bool, optional (default: False)
        If True, return a copy instead of modifying in-place

    Returns
    -------
    AnnData or None
        Returns adata if copy=True, otherwise None. Adds to
        adata.uns['scHopfield'], creating that dict if it does not exist yet:
        - 'future_celltype_correlation': symmetric DataFrame of pairwise
          coefficients, with 1.0 on the diagonal
    """
    from .._utils.math import sigmoid

    adata = adata.copy() if copy else adata
    genes = get_genes_used(adata)
    labels = adata.obs[cluster_key]
    keys = labels.unique()
    corr_f = ho.mat_corr_coeff.RV2coeff if modified else ho.mat_corr_coeff.RVcoeff
    rv = pd.DataFrame(index=keys, columns=keys, data=1.0)

    counts = get_matrix(adata, spliced_key, genes=genes)
    # Row vectors so they broadcast against (cells x genes) expression.
    threshold = adata.var['sigmoid_threshold'].values[genes][None, :]
    exponent = adata.var['sigmoid_exponent'].values[genes][None, :]
    n_used = threshold.shape[1]

    def _future_state(k):
        """Return ``W_k @ sigmoid(x_k).T`` for the cells of cluster ``k``."""
        x = to_numpy(counts[(labels == k).values])
        sig = sigmoid(x, threshold, exponent)
        W = adata.varp[f'W_{k}']
        if W.shape[1] != n_used:
            # varp matrices are (n_vars, n_vars); restrict to the analysed
            # genes when only a subset is used.
            # NOTE(review): assumes W rows/cols are ordered like adata.var.
            W = W[np.ix_(genes, genes)]
        return W @ sig.T

    for k1, k2 in itertools.combinations(keys, 2):
        value = corr_f([_future_state(k1), _future_state(k2)])[0, 1]
        rv.loc[k1, k2] = value
        rv.loc[k2, k1] = value

    adata.uns.setdefault('scHopfield', {})['future_celltype_correlation'] = rv
    return adata if copy else None


def get_correlation_table(
    adata: AnnData,
    cluster_key: str = 'cell_type',
    energy_type: str = 'total',
    n_top_genes: int = 20,
    order: Optional[list] = None
) -> pd.DataFrame:
    """
    Get correlation table with top genes per cluster.

    Creates a formatted table showing the top N genes correlated with energy
    for each cluster.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with computed energy-gene correlations
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    energy_type : str, optional (default: 'total')
        Type of energy correlation: 'total', 'interaction', 'degradation', or 'bias'
    n_top_genes : int, optional (default: 20)
        Number of top correlated genes to show per cluster. If fewer genes
        are available, the remaining rows are NaN.
    order : list, optional
        Order of clusters to display. If None, uses all unique clusters

    Returns
    -------
    pd.DataFrame
        DataFrame with MultiIndex columns (cluster, ['Gene', 'Correlation'])
        showing the top correlated genes for each cluster, sorted by
        decreasing correlation. NaN correlations sort last.

    Raises
    ------
    ValueError
        If none of the requested clusters has correlation data.

    Warns
    -----
    UserWarning
        For each cluster in ``order`` whose correlation column is missing;
        that cluster's columns are left empty.
    """
    import warnings

    genes = get_genes_used(adata)
    gene_names = adata.var.index[genes]

    if order is None:
        order = adata.obs[cluster_key].unique().tolist()

    # Fail only if nothing is available; individual missing clusters are
    # skipped with a warning below.
    if not any(f'correlation_{energy_type}_{c}' in adata.var.columns for c in order):
        raise ValueError(
            "No correlation data found. Please run sch.tl.energy_gene_correlation() first."
        )

    # Create DataFrame with MultiIndex columns
    df = pd.DataFrame(
        index=range(n_top_genes),
        columns=pd.MultiIndex.from_product([order, ['Gene', 'Correlation']])
    )

    for cluster in order:
        corr_col = f'correlation_{energy_type}_{cluster}'
        if corr_col not in adata.var.columns:
            warnings.warn(
                f"Warning: No correlation data for cluster '{cluster}', skipping...",
                stacklevel=2,
            )
            continue

        corrs = adata.var[corr_col].values[genes]
        # Descending order; negating (rather than reversing) keeps NaN last.
        indices = np.argsort(-corrs, kind='stable')[:n_top_genes]

        # Series align on the RangeIndex, so if fewer than n_top_genes genes
        # exist, the trailing rows stay NaN instead of raising a length error.
        df[(cluster, 'Gene')] = pd.Series(np.asarray(gene_names[indices]))
        df[(cluster, 'Correlation')] = pd.Series(corrs[indices])

    return df