# Source code for scHopfield.tools.energy

"""Energy landscape computation."""

import numpy as np
from typing import Optional
from anndata import AnnData

from .._utils.math import sigmoid, int_sig_act_inv
from .._utils.io import get_matrix, get_genes_used, ensure_sigmoid_layer


def compute_energies(
    adata: AnnData,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    cluster_key: str = 'cell_type',
    copy: bool = False
) -> Optional[AnnData]:
    """
    Calculate energy landscapes for all clusters.

    Computes total energy and its components (interaction, degradation,
    bias) for each cell based on the inferred gene regulatory network.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels. Every non-missing label must
        have fitted parameters ``adata.varp['W_<label>']`` and
        ``adata.var['I_<label>']``.
    copy : bool, optional (default: False)
        If True, return a copy instead of modifying in-place

    Returns
    -------
    AnnData or None
        Returns adata if copy=True, otherwise None. Adds to adata.obs:
        - 'energy_total'
        - 'energy_interaction'
        - 'energy_degradation'
        - 'energy_bias'
        Each cell's energy is computed using the parameters of its own
        cluster. Cells with a missing (NaN) cluster label are not updated:
        they keep 0.0 or whatever value the column already held.

    Raises
    ------
    KeyError
        If a cluster label has no fitted ``W_<label>`` in adata.varp or no
        ``I_<label>`` in adata.var. This is checked before adata is modified.

    Notes
    -----
    Energy formula:
        E = -0.5 * s^T W s + gamma * integral(sigmoid^-1) - I^T s
    """
    adata = adata.copy() if copy else adata

    # Labelled clusters only: a NaN label has no parameters and would
    # otherwise be looked up as 'W_nan'.
    clusters = adata.obs[cluster_key].dropna().unique()

    # Validate every cluster before writing anything, so a missing fit
    # cannot leave adata.obs half-updated.
    missing = [
        c for c in clusters
        if f'W_{c}' not in adata.varp or f'I_{c}' not in adata.var
    ]
    if missing:
        raise KeyError(
            f"No fitted parameters (W_<cluster>/I_<cluster>) for clusters: {missing}"
        )

    ensure_sigmoid_layer(adata, spliced_key)

    # Create each output column that does not exist yet; existing columns
    # are kept and overwritten below for every labelled cell.
    for col in ('energy_total', 'energy_interaction',
                'energy_degradation', 'energy_bias'):
        if col not in adata.obs:
            adata.obs[col] = 0.0

    for cluster in clusters:
        # A cluster literally named 'all' is applied to every cell.
        if cluster == 'all':
            idx = np.ones(adata.n_obs, dtype=bool)
        else:
            idx = (adata.obs[cluster_key] == cluster).values

        # Energy components for this cluster's cells, using its own parameters.
        e_int = _interaction_energy(adata, cluster, cluster_key)
        e_deg = _degradation_energy(adata, cluster, spliced_key, degradation_key, cluster_key)
        e_bias = _bias_energy(adata, cluster, spliced_key, cluster_key, x=None)

        adata.obs.loc[idx, 'energy_total'] = e_int + e_deg + e_bias
        adata.obs.loc[idx, 'energy_interaction'] = e_int
        adata.obs.loc[idx, 'energy_degradation'] = e_deg
        adata.obs.loc[idx, 'energy_bias'] = e_bias

    return adata if copy else None
def _cluster_sigmoid(
    adata: AnnData,
    genes,
    cluster: str,
    cluster_key: str,
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Return sigmoid activations used by the energy helpers.

    If ``x`` is given, the per-gene sigmoid (adata.var 'sigmoid_threshold'
    and 'sigmoid_exponent') is applied to it and non-finite values are
    cleaned with ``np.nan_to_num``. Otherwise the stored 'sigmoid' matrix is
    read for the cells of ``cluster``; the cluster name 'all' selects every
    cell.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    genes
        Gene selector as returned by ``get_genes_used``
    cluster : str
        Cluster whose cells are selected when ``x`` is None
    cluster_key : str
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Sigmoid activations, one row per cell
    """
    if x is not None:
        threshold = adata.var['sigmoid_threshold'].values[genes]
        exponent = adata.var['sigmoid_exponent'].values[genes]
        return np.nan_to_num(sigmoid(x, threshold[None, :], exponent[None, :]))

    if cluster == 'all':
        idx = slice(None)
    else:
        idx = (adata.obs[cluster_key] == cluster).values
    return get_matrix(adata, 'sigmoid', genes=genes)[idx]


def _interaction_energy(
    adata: AnnData,
    cluster: str,
    cluster_key: str,
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Calculate interaction energy component: -0.5 * s^T W s.

    Computes the energy contribution from gene-gene interactions using the
    cluster-specific interaction matrix W and sigmoid activations.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    cluster : str
        Cluster name to use for interaction matrix ``adata.varp['W_<cluster>']``.
        The name 'all' selects every cell when ``x`` is None.
    cluster_key : str
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Array of interaction energies for each cell
    """
    genes = get_genes_used(adata)
    sig = _cluster_sigmoid(adata, genes, cluster, cluster_key, x)

    # NOTE(review): W comes from adata.varp while sig holds only the `genes`
    # columns; the product assumes their sizes agree - confirm against
    # get_genes_used / the fitting code.
    W = adata.varp[f'W_{cluster}']

    # Per-cell quadratic form s^T W s, vectorized over rows of sig.
    return -0.5 * np.sum((sig @ W.T) * sig, axis=1)


def _degradation_energy(
    adata: AnnData,
    cluster: str,
    spliced_key: str,
    degradation_key: str,
    cluster_key: str = 'cell_type',
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Calculate degradation energy using integral of inverse sigmoid.

    Computes the energy contribution from mRNA degradation by integrating
    the inverse sigmoid function, weighted by degradation rates. If a
    cluster-specific column ``adata.var['gamma_<cluster>']`` exists it is
    used, as in ``decompose_degradation_energy``. Otherwise
    ``adata.var[degradation_key]`` is used.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    cluster : str
        Cluster name to use for parameters. The name 'all' selects every
        cell when ``x`` is None.
    spliced_key : str
        Key in adata.layers for spliced counts (not read here)
    degradation_key : str
        Key in adata.var for degradation rates, used when no
        'gamma_<cluster>' column exists
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Array of degradation energies for each cell
    """
    genes = get_genes_used(adata)

    # Prefer cluster-specific degradation rates so the per-gene
    # decomposition sums to this total.
    gamma_col = f'gamma_{cluster}'
    rate_key = gamma_col if gamma_col in adata.var else degradation_key
    g = adata.var[rate_key].values[genes]

    threshold = adata.var['sigmoid_threshold'].values[genes]
    exponent = adata.var['sigmoid_exponent'].values[genes]

    sig = _cluster_sigmoid(adata, genes, cluster, cluster_key, x)

    integral = int_sig_act_inv(sig, threshold, exponent)
    return np.sum(g[None, :] * integral, axis=1)


def _bias_energy(
    adata: AnnData,
    cluster: str,
    spliced_key: str,
    cluster_key: str = 'cell_type',
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Calculate bias energy component: -I^T s.

    Computes the energy contribution from external inputs or biases using
    the cluster-specific bias vector I and sigmoid activations.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    cluster : str
        Cluster name to use for bias vector ``adata.var['I_<cluster>']``.
        The name 'all' selects every cell when ``x`` is None.
    spliced_key : str
        Key in adata.layers for spliced counts (not read here)
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Array of bias energies for each cell
    """
    genes = get_genes_used(adata)
    sig = _cluster_sigmoid(adata, genes, cluster, cluster_key, x)

    bias_vector = adata.var[f'I_{cluster}'].values[genes]
    return -np.sum(bias_vector[None, :] * sig, axis=1)
def decompose_degradation_energy(
    adata: AnnData,
    cluster: str,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    cluster_key: str = 'cell_type',
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Calculate gene-wise degradation energy decomposition.

    Returns the degradation-energy term separately for each gene, so the
    genes contributing most to the total degradation energy can be
    identified. The degradation rates come from
    ``adata.var['gamma_<cluster>']`` when that column exists, otherwise from
    ``adata.var[degradation_key]``.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    cluster : str
        Cluster name to use for parameters. The name 'all' selects every
        cell when ``x`` is None.
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts; only used to make sure the
        sigmoid layer exists when ``x`` is None
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates (fallback column)
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Array of shape (n_cells, n_genes) with degradation energy per gene
    """
    use_stored = x is None
    if use_stored:
        ensure_sigmoid_layer(adata, spliced_key)

    genes = get_genes_used(adata)

    # A per-cluster rate column takes precedence over the shared one.
    per_cluster_col = f'gamma_{cluster}'
    rate_col = per_cluster_col if per_cluster_col in adata.var else degradation_key
    rates = adata.var[rate_col].values[genes]

    thr = adata.var['sigmoid_threshold'].values[genes]
    expo = adata.var['sigmoid_exponent'].values[genes]

    if not use_stored:
        act = np.nan_to_num(sigmoid(x, thr[None, :], expo[None, :]))
    else:
        rows = slice(None) if cluster == 'all' else (adata.obs[cluster_key] == cluster).values
        act = get_matrix(adata, 'sigmoid', genes=genes)[rows]

    return rates[None, :] * int_sig_act_inv(act, thr, expo)
def decompose_bias_energy(
    adata: AnnData,
    cluster: str,
    spliced_key: str = 'Ms',
    cluster_key: str = 'cell_type',
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Calculate gene-wise bias energy decomposition.

    Returns -I_g * s_g for every gene g. I is the bias vector
    ``adata.var['I_<cluster>']`` and s is the sigmoid activation. Summing
    over genes gives the per-cell bias energy -I^T s.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    cluster : str
        Cluster name to use for bias vector. The name 'all' selects every
        cell when ``x`` is None.
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts; only used to make sure the
        sigmoid layer exists when ``x`` is None
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Array of shape (n_cells, n_genes) with bias energy per gene
    """
    stored = x is None
    if stored:
        ensure_sigmoid_layer(adata, spliced_key)
    genes = get_genes_used(adata)

    if stored:
        selector = slice(None) if cluster == 'all' else (adata.obs[cluster_key] == cluster).values
        activations = get_matrix(adata, 'sigmoid', genes=genes)[selector]
    else:
        var = adata.var
        thr = var['sigmoid_threshold'].values[genes]
        expo = var['sigmoid_exponent'].values[genes]
        activations = np.nan_to_num(sigmoid(x, thr[None, :], expo[None, :]))

    inputs = adata.var[f'I_{cluster}'].values[genes]
    return -inputs[None, :] * activations
def decompose_interaction_energy(
    adata: AnnData,
    cluster: str,
    side: str = 'in',
    spliced_key: str = 'Ms',
    cluster_key: str = 'cell_type',
    x: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Calculate gene-wise interaction energy.

    Adapted from Landscape.interaction_energy_decomposed. For either
    ``side``, summing the result over genes gives the per-cell interaction
    energy -0.5 * s^T W s, because s^T W s == s^T W^T s.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted parameters
    cluster : str
        Cluster name to use for interaction matrix ``adata.varp['W_<cluster>']``.
        The name 'all' selects every cell when ``x`` is None.
    side : str, optional (default: 'in')
        'in' for incoming interactions, 'out' for outgoing interactions.
        For gene i, 'out' weights s_i by sum_j W[i, j] * s_j (row i of W).
        'in' weights s_i by sum_j W[j, i] * s_j (column i of W).
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts; only used to make sure the
        sigmoid layer exists when ``x`` is None
    cluster_key : str, optional (default: 'cell_type')
        Key in adata.obs for cluster labels
    x : np.ndarray, optional
        Optional expression data. If None, uses stored sigmoid values

    Returns
    -------
    np.ndarray
        Array of shape (n_cells, n_genes) with interaction energy per gene

    Raises
    ------
    ValueError
        If ``side`` is neither 'in' nor 'out'.
    """
    # Reject unknown values instead of silently treating them as 'in'.
    if side not in ('in', 'out'):
        raise ValueError(f"side must be 'in' or 'out', got {side!r}")

    if x is None:
        ensure_sigmoid_layer(adata, spliced_key)
    genes = get_genes_used(adata)

    if x is not None:
        threshold = adata.var['sigmoid_threshold'].values[genes]
        exponent = adata.var['sigmoid_exponent'].values[genes]
        sig = np.nan_to_num(sigmoid(x, threshold[None, :], exponent[None, :]))
    else:
        if cluster == 'all':
            idx = slice(None)
        else:
            idx = (adata.obs[cluster_key] == cluster).values
        sig = get_matrix(adata, 'sigmoid', genes=genes)[idx]

    W = adata.varp[f'W_{cluster}']

    # NOTE(review): whether 'in'/'out' correspond to incoming/outgoing
    # regulation depends on the orientation convention of W (target-by-source
    # vs source-by-target) - confirm against the fitting code.
    if side == 'out':
        return -0.5 * (sig @ W.T) * sig
    return -0.5 * (sig @ W) * sig