scHopfield.inference.fit_interactions
- scHopfield.inference.fit_interactions(adata: AnnData, cluster_key: str, spliced_key: str = 'Ms', velocity_key: str = 'velocity_S', degradation_key: str = 'gamma', w_threshold: float = 1e-05, w_scaffold: ndarray | None = None, scaffold_regularization: float = 1.0, reconstruction_regularization: float = 1.0, bias_regularization: float = 1.0, bias_bias: float = 0.0, only_TFs: bool = False, infer_I: bool = False, refit_gamma: bool = False, pre_initialize_W: bool = False, n_epochs: int = 1000, criterion: str = 'L2', batch_size: int = 64, device: str = 'cpu', skip_all: bool = False, learning_rate: float = 0.1, use_scheduler: bool = False, scheduler_kws: Dict | None = None, use_plateau_scheduler: bool = False, plateau_patience: int = 50, plateau_factor: float = 0.5, plateau_min_lr: float = 1e-06, get_plots: bool = False, hierarchical_pretrain: bool = False, hierarchy_keys: List[str] | None = None, hierarchy_mappings: List[Dict[str, str]] | None = None, drop_last: bool = True, balanced_sampling: bool = False, normalize_regularization: bool = False, include_neighbors: bool = False, neighbors_key: str = 'connectivities', neighbor_fraction: float = 0.0, hierarchical_scaling: bool = False, copy: bool = False) → AnnData | None
Infer gene regulatory network interaction matrices.
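Fits, for each cluster (and, unless skip_all=True, for all cells pooled together), an interaction matrix W and a bias vector I such that velocity ≈ W · sigmoid(expression) - gamma * expression + I, where W · sigmoid(expression) is a matrix-vector product, sigmoid is the per-gene sigmoid fitted beforehand, and gamma * expression is the element-wise product with the per-gene degradation rates.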
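To make the model concrete, here is a small NumPy sketch that evaluates the right-hand side for every cell and recovers W and I from noise-free synthetic data with an ordinary least-squares solve, the kind of fit that pre_initialize_W and infer_I refer to. It is an illustration under stated assumptions, not scHopfield's implementation: the sigmoid is assumed to be a generic per-gene logistic with hypothetical parameters `threshold` and `exponent`, cells are rows, and `W[i, j]` is assumed to be the effect of gene j on gene i.

```python
import numpy as np


def sigmoid(x, threshold, exponent):
    """Generic per-gene logistic; `threshold` and `exponent` are hypothetical names."""
    return 1.0 / (1.0 + np.exp(-exponent * (x - threshold)))


def predict_velocity(X, W, I, gamma, threshold, exponent):
    """velocity = W @ sigmoid(x) - gamma * x + I, evaluated for every cell (row of X)."""
    S = sigmoid(X, threshold, exponent)       # (n_cells, n_genes)
    return S @ W.T - X * gamma + I            # assumed: W[i, j] = effect of gene j on gene i


def least_squares_fit(X, V, gamma, threshold, exponent, infer_I=True):
    """Solve for W (and optionally I) by ordinary least squares."""
    S = sigmoid(X, threshold, exponent)
    Y = V + X * gamma                          # move the degradation term to the target side
    # Append a column of ones so the intercept row of B is the bias vector I.
    A = np.hstack([S, np.ones((S.shape[0], 1))]) if infer_I else S
    B, *_ = np.linalg.lstsq(A, Y, rcond=None)  # A @ B ≈ Y
    W = B[: S.shape[1]].T
    I = B[-1] if infer_I else np.zeros(S.shape[1])
    return W, I


# Synthetic check: generate velocities from a known W and I, then recover them.
rng = np.random.default_rng(0)
n_cells, n_genes = 500, 4
X = rng.uniform(0.0, 2.0, size=(n_cells, n_genes))
gamma = rng.uniform(0.1, 1.0, size=n_genes)
threshold = np.full(n_genes, 1.0)
exponent = np.full(n_genes, 3.0)
W_true = rng.normal(size=(n_genes, n_genes))
I_true = rng.normal(size=n_genes)
V = predict_velocity(X, W_true, I_true, gamma, threshold, exponent)

W_hat, I_hat = least_squares_fit(X, V, gamma, threshold, exponent)
print(np.allclose(W_hat, W_true), np.allclose(I_hat, I_true))
```

With real, noisy velocities the least-squares solution is only a starting point, which is why the function otherwise refines W by gradient-based training.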
- Parameters:
adata (AnnData) – Annotated data object with fitted sigmoid parameters
cluster_key (str) – Key in adata.obs containing cluster annotations
spliced_key (str, optional (default: 'Ms')) – Key in adata.layers for spliced counts
velocity_key (str, optional (default: 'velocity_S')) – Key in adata.layers for RNA velocity
degradation_key (str, optional (default: 'gamma')) – Key in adata.var for degradation rates
w_threshold (float, optional (default: 1e-5)) – Threshold for pruning small interaction weights
w_scaffold (np.ndarray, optional) – Binary scaffold matrix constraining network topology
scaffold_regularization (float, optional (default: 1.0)) – Regularization strength for scaffold constraint
reconstruction_regularization (float, optional (default: 1.0)) – Regularization strength for reconstruction loss
bias_regularization (float, optional (default: 1.0)) – Regularization strength for bias vector
bias_bias (float, optional (default: 0.0)) – Offset used in the bias regularization to shift the learned bias values in one direction (e.g., a negative bias_bias encourages more positive biases)
only_TFs (bool, optional (default: False)) – If True, use masked linear layer (requires w_scaffold)
infer_I (bool, optional (default: False)) – If True, also infer the bias vector I in the least-squares fit
refit_gamma (bool, optional (default: False)) – If True, refit degradation rates during optimization
pre_initialize_W (bool, optional (default: False)) – If True, initialize W with least squares solution
n_epochs (int, optional (default: 1000)) – Number of training epochs
criterion (str, optional (default: 'L2')) – Loss function: ‘L1’, ‘L2’, or ‘MSE’
batch_size (int, optional (default: 64)) – Batch size for training
device (str, optional (default: 'cpu')) – Device for computation: ‘cpu’ or ‘cuda’
skip_all (bool, optional (default: False)) – If True, skip fitting the model on all cells pooled together (the 'all' group) and fit only the cluster-specific models
learning_rate (float, optional (default: 0.1)) – Initial learning rate for training
use_scheduler (bool, optional (default: False)) – If True, use StepLR learning rate scheduler
scheduler_kws (dict, optional) – Keyword arguments for StepLR scheduler
use_plateau_scheduler (bool, optional (default: False)) – If True, use ReduceLROnPlateau scheduler that decreases learning rate when the loss plateaus. This overrides use_scheduler.
plateau_patience (int, optional (default: 50)) – Number of epochs with no improvement after which learning rate will be reduced
plateau_factor (float, optional (default: 0.5)) – Factor by which the learning rate will be reduced (new_lr = lr * factor)
plateau_min_lr (float, optional (default: 1e-6)) – Minimum learning rate for plateau scheduler
get_plots (bool, optional (default: False)) – If True, show training plots
hierarchical_pretrain (bool, optional (default: False)) – If True, enable hierarchical pretraining. First trains on all cells, then uses those parameters to initialize cluster-specific training. If hierarchy_keys is provided, trains through multiple levels.
hierarchy_keys (list of str, optional) – List of obs keys from coarse to fine clustering (e.g., [‘lineage’, ‘cell_type’]). Only used if hierarchical_pretrain=True. If None, uses simple two-level hierarchy: ‘all’ → cluster_key.
hierarchy_mappings (list of dict, optional) – List of mappings between consecutive hierarchy levels. Each mapping is {fine_cluster: coarse_cluster}. Must have len(hierarchy_keys) - 1 elements. Example: [{‘T_cell’: ‘immune’, ‘B_cell’: ‘immune’, ‘Fibroblast’: ‘stromal’}]
drop_last (bool, optional (default: True)) – If True, drop the last incomplete batch to ensure consistent batch sizes. This reduces gradient variance from small tail-end batches.
balanced_sampling (bool, optional (default: False)) – If True, use weighted sampling to balance cluster representation when training on multiple clusters (e.g., when training ‘all’ during hierarchical pretraining). Requires hierarchical_pretrain=True.
normalize_regularization (bool, optional (default: False)) – If True, normalize scaffold and bias regularization losses by batch size. This keeps regularization balanced with reconstruction loss when batch sizes vary. Alternative to drop_last for handling batch inconsistency.
include_neighbors (bool, optional (default: False)) – If True, include neighboring cells (from any cluster) when training cluster-specific models. Neighbors are determined from the connectivity matrix. Only applies to non-‘all’ clusters.
neighbors_key (str, optional (default: 'connectivities')) – Key in adata.obsp containing the cell-cell connectivity matrix. If not found, neighbors will be computed using scanpy.
neighbor_fraction (float, optional (default: 0.0)) – Fraction of each training batch that should come from neighboring cells (cells not in the cluster but connected via the neighbor graph). Only applies when include_neighbors=True. Value must be in [0.0, 1.0). Example: 0.2 means 20% of each batch are neighbors, 80% cluster cells.
hierarchical_scaling (bool, optional (default: False)) – If True and hierarchical_pretrain=True, train the pretraining levels (every level except the finest) for half as many epochs and set each child level's initial learning rate from its parent's final learning rate: the child starts at the learning rate whose base-10 exponent is half the parent's final exponent (e.g., parent ends at 1e-8, child starts at 1e-4).
copy (bool, optional (default: False)) – If True, return a copy instead of modifying in-place
- Returns:
Returns a modified copy of adata if copy=True; otherwise modifies adata in place and returns None. Adds the following entries:
- adata.varp[f'W_{cluster}']: interaction matrix for each cluster
- adata.var[f'I_{cluster}']: bias vector for each cluster
- adata.var[f'gamma_{cluster}']: refitted gamma if refit_gamma=True
- adata.uns['scHopfield']['models'][cluster]: trained models if w_scaffold is provided
- Return type:
AnnData or None
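The regularization parameters can be read as weights on separate terms of one training objective. The NumPy sketch below is one hypothetical combination, not the package's actual loss: the reconstruction term is assumed to be a mean squared error, the scaffold penalty is assumed to be an L1 penalty on weights where w_scaffold is 0, and the bias penalty is assumed to be a squared penalty on I + bias_bias, which is minimized at I = -bias_bias and so matches the documented effect that a negative bias_bias favors positive biases. `masked_weights` illustrates the masked-linear idea behind only_TFs, and `prune` zeroes weights whose magnitude is below w_threshold; all helper names are hypothetical.

```python
import numpy as np


def objective(V_pred, V_obs, W, I, w_scaffold,
              reconstruction_regularization=1.0,
              scaffold_regularization=1.0,
              bias_regularization=1.0,
              bias_bias=0.0):
    """Hypothetical total loss; the real term definitions may differ."""
    reconstruction = np.mean((V_pred - V_obs) ** 2)
    # Assumed L1 penalty on edges the scaffold does not allow (w_scaffold == 0).
    off_scaffold = np.abs(W * (1 - w_scaffold)).sum()
    # Minimized at I = -bias_bias, so bias_bias < 0 pulls biases toward positive values.
    bias_penalty = np.mean((I + bias_bias) ** 2)
    return (reconstruction_regularization * reconstruction
            + scaffold_regularization * off_scaffold
            + bias_regularization * bias_penalty)


def masked_weights(W, w_scaffold):
    """Masked-linear idea: entries outside the scaffold are forced to 0."""
    return W * w_scaffold


def prune(W, w_threshold=1e-5):
    """Return a copy of W with entries of magnitude below w_threshold set to 0."""
    W = W.copy()
    W[np.abs(W) < w_threshold] = 0.0
    return W


scaffold = np.array([[1, 0], [1, 1]])
W = np.array([[0.5, 0.2], [1e-6, -0.3]])
I = np.array([0.1, -0.1])
print(masked_weights(W, scaffold))
print(prune(W))
# Perfect reconstruction, so only the scaffold and bias penalties contribute.
loss = objective(np.zeros((3, 2)), np.zeros((3, 2)), W, I, scaffold, bias_bias=-0.1)
print(loss)
```

In a masked layer the disallowed entries receive no gradient, whereas a scaffold penalty only discourages them; the sketch shows both because the parameters describe both mechanisms.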
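The learning-rate options reduce to simple arithmetic, sketched below in pure Python under simplifying assumptions. `step_lr` follows PyTorch's StepLR rule lr0 * gamma ** (epoch // step_size), where step_size and gamma are the StepLR arguments one would pass through scheduler_kws (this gamma is unrelated to degradation rates). `PlateauLR` is a simplified ReduceLROnPlateau that, like PyTorch, reduces the rate once more than `patience` epochs pass without improvement, but ignores PyTorch's relative-improvement threshold and cooldown. `child_start_lr` implements the documented hierarchical_scaling rule of halving the parent's final base-10 exponent. The names are illustrative, not scHopfield functions.

```python
import math


def step_lr(lr0, epoch, step_size, gamma=0.1):
    """StepLR rule: multiply the learning rate by gamma every step_size epochs."""
    return lr0 * gamma ** (epoch // step_size)


class PlateauLR:
    """Simplified ReduceLROnPlateau (mode='min', no threshold, no cooldown)."""

    def __init__(self, lr, patience=50, factor=0.5, min_lr=1e-6):
        self.lr, self.patience, self.factor, self.min_lr = lr, patience, factor, min_lr
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, loss):
        if loss < self.best:
            self.best, self.bad_epochs = loss, 0
        else:
            self.bad_epochs += 1
        # Reduce once more than `patience` epochs pass without improvement.
        if self.bad_epochs > self.patience:
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.bad_epochs = 0
        return self.lr


def child_start_lr(parent_final_lr):
    """hierarchical_scaling: halve the base-10 exponent of the parent's final LR."""
    return 10 ** (math.log10(parent_final_lr) / 2)


print(step_lr(0.1, epoch=250, step_size=100, gamma=0.5))  # decayed twice by 0.5
sched = PlateauLR(lr=0.1, patience=2)
for loss in [1.0, 0.9, 0.9, 0.9, 0.9]:  # the last three epochs do not improve
    lr = sched.step(loss)
print(lr)
print(child_start_lr(1e-8))
```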
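The batching options can be illustrated with NumPy. The helpers below are hypothetical and only show the ideas the parameters describe, not how scHopfield builds its batches. `n_batches` reproduces PyTorch DataLoader semantics for drop_last. `balanced_weights` gives each cell a sampling weight inversely proportional to its cluster size, so each cluster is drawn equally often on average (the idea behind balanced_sampling). `neighbor_cells` selects cells outside a cluster that have a non-zero connectivity to at least one cluster cell, and `batch_split` assumes the neighbor share of a batch is int(batch_size * neighbor_fraction).

```python
import math
import numpy as np


def n_batches(n_cells, batch_size, drop_last=True):
    """Number of batches per epoch, as with a PyTorch DataLoader."""
    return n_cells // batch_size if drop_last else math.ceil(n_cells / batch_size)


def balanced_weights(labels):
    """Per-cell sampling weights inversely proportional to cluster size, summing to 1."""
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    weights = 1.0 / counts[inverse]
    return weights / weights.sum()


def neighbor_cells(connectivities, in_cluster):
    """Indices of cells outside the cluster connected to at least one cluster cell."""
    linked = (connectivities[in_cluster] != 0).any(axis=0)
    return np.flatnonzero(linked & ~in_cluster)


def batch_split(batch_size, neighbor_fraction):
    """Split a batch into (number of cluster cells, number of neighbor cells)."""
    if not 0.0 <= neighbor_fraction < 1.0:
        raise ValueError("neighbor_fraction must be in [0.0, 1.0)")
    n_neighbors = int(batch_size * neighbor_fraction)
    return batch_size - n_neighbors, n_neighbors


labels = np.array(["A"] * 6 + ["B"] * 2)
weights = balanced_weights(labels)
conn = np.zeros((5, 5))
conn[0, 3] = conn[3, 0] = 0.8   # cell 3 (outside) is connected to cell 0 (inside)
in_cluster = np.array([True, True, False, False, False])
print(n_batches(1000, 64), n_batches(1000, 64, drop_last=False))
print(weights[labels == "A"].sum(), weights[labels == "B"].sum())
print(neighbor_cells(conn, in_cluster))
print(batch_split(64, 0.2))
```

With 1000 cells and batch_size=64, dropping the last batch discards a 40-cell tail, which is the small, high-variance batch that drop_last is meant to avoid.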
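Together, hierarchy_keys and hierarchy_mappings describe a chain of parent clusters. The hypothetical helpers below (not part of scHopfield) check the documented length constraint and walk from a fine cluster up through the coarser levels, which is the order in which a cluster-specific model would inherit its initial parameters during hierarchical pretraining. Placing the pooled 'all' model at the top of the chain is an assumption based on the two-level default described for hierarchy_keys.

```python
from typing import Dict, List


def validate_hierarchy(hierarchy_keys: List[str],
                       hierarchy_mappings: List[Dict[str, str]]) -> None:
    """Check the documented constraint len(mappings) == len(keys) - 1."""
    if len(hierarchy_mappings) != len(hierarchy_keys) - 1:
        raise ValueError("hierarchy_mappings must have len(hierarchy_keys) - 1 elements")


def ancestry(fine_cluster: str, hierarchy_mappings: List[Dict[str, str]]) -> List[str]:
    """Return [fine_cluster, parent, ..., 'all'], finest level first."""
    chain = [fine_cluster]
    # mappings[k] maps level k + 1 to level k, so walk them from the last to the first.
    for mapping in reversed(hierarchy_mappings):
        chain.append(mapping[chain[-1]])
    chain.append("all")  # assumed top level: the model trained on all cells
    return chain


keys = ["lineage", "cell_type"]
mappings = [{"T_cell": "immune", "B_cell": "immune", "Fibroblast": "stromal"}]
validate_hierarchy(keys, mappings)
print(ancestry("T_cell", mappings))
```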