Source code for scHopfield.inference.interactions

"""Inference of gene regulatory network interactions."""

import math
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler
from typing import Optional, List, Dict, Tuple
from anndata import AnnData

from .optimizer import ScaffoldOptimizer
from .datasets import CustomDataset
from .._utils.io import get_matrix, to_numpy, get_genes_used


def _build_hierarchy_levels(
    adata: AnnData,
    cluster_key: str,
    hierarchy_keys: Optional[List[str]],
    hierarchy_mappings: Optional[List[Dict[str, str]]]
) -> List[Tuple[str, List[str], Optional[Dict[str, str]]]]:
    """
    Build the list of training levels from coarse to fine.

    Parameters
    ----------
    adata : AnnData
        Annotated data object; only ``adata.obs[<key>].unique()`` is read.
    cluster_key : str
        Key for the finest level clustering (used if hierarchy_keys is None)
    hierarchy_keys : list of str, optional
        List of obs keys from coarse to fine (e.g., ['lineage', 'cell_type'])
    hierarchy_mappings : list of dict, optional
        List of {fine_cluster: coarse_cluster} mappings between consecutive
        levels. ``None`` is accepted when ``hierarchy_keys`` has a single
        entry, because no mapping is needed in that case.

    Returns
    -------
    List of tuples: (obs_key, clusters_at_this_level, mapping_to_parent)
    The first level is always 'all' with mapping None.

    Raises
    ------
    ValueError
        If ``hierarchy_keys`` is empty, if the number of mappings is not
        ``len(hierarchy_keys) - 1``, or if a fine cluster has no entry in
        its mapping.
    """
    levels = [('all', ['all'], None)]

    if hierarchy_keys is None:
        # Simple case: all -> cluster_key
        clusters = list(adata.obs[cluster_key].unique())
        levels.append((cluster_key, clusters, dict.fromkeys(clusters, 'all')))
        return levels

    if len(hierarchy_keys) == 0:
        raise ValueError(
            "hierarchy_keys must contain at least one obs key "
            "(pass None to use cluster_key directly)"
        )

    # A single hierarchy key needs no fine->coarse mapping, so None is
    # equivalent to an empty list here.
    mappings = [] if hierarchy_mappings is None else hierarchy_mappings
    n_expected = len(hierarchy_keys) - 1
    if len(mappings) != n_expected:
        raise ValueError(
            f"hierarchy_mappings must have {n_expected} elements "
            f"(one fewer than hierarchy_keys), got {len(mappings)}"
        )

    # First level after 'all' - the coarsest clustering; every cluster's
    # parent is the global 'all' fit.
    coarse_key = hierarchy_keys[0]
    coarse_clusters = list(adata.obs[coarse_key].unique())
    levels.append((coarse_key, coarse_clusters, dict.fromkeys(coarse_clusters, 'all')))

    # Subsequent levels: mappings[i] links hierarchy_keys[i + 1] to hierarchy_keys[i]
    for i, (fine_key, fine_to_coarse) in enumerate(zip(hierarchy_keys[1:], mappings)):
        fine_clusters = list(adata.obs[fine_key].unique())

        # Every cluster present in the data must know its parent
        for fc in fine_clusters:
            if fc not in fine_to_coarse:
                raise ValueError(
                    f"Fine cluster '{fc}' from '{fine_key}' not found in hierarchy_mappings[{i}]"
                )

        levels.append((fine_key, fine_clusters, fine_to_coarse))

    return levels


def _get_cluster_with_neighbors(adata, cluster_idx, neighbors_key='connectivities'):
    """
    Expand cluster selection to include neighboring cells.

    Parameters
    ----------
    adata : AnnData
        Annotated data object. If ``neighbors_key`` is missing from
        ``adata.obsp``, ``scanpy.pp.neighbors(adata)`` is run, which
        modifies ``adata`` in place.
    cluster_idx : np.ndarray
        Boolean array of cells in the cluster (not modified)
    neighbors_key : str
        Key in adata.obsp for connectivity matrix (dense or scipy sparse)

    Returns
    -------
    tuple
        (expanded_idx, is_neighbor) where:
        - expanded_idx: Boolean array including cluster cells and neighbors
        - is_neighbor: Boolean array (same size as expanded_idx.sum()) where
          True indicates the cell is a neighbor, False indicates cluster member
    """
    from scipy import sparse

    # Get or compute connectivity matrix
    if neighbors_key not in adata.obsp:
        print(f"  Computing neighbors ('{neighbors_key}' not found in adata.obsp)")
        # scanpy is only needed on this path, so import it lazily
        import scanpy as sc
        sc.pp.neighbors(adata)
        if neighbors_key not in adata.obsp:
            # sc.pp.neighbors(adata) without key_added stores its graph under
            # the default 'connectivities' key; use that instead of failing
            # with a KeyError on the custom key.
            neighbors_key = 'connectivities'

    conn = adata.obsp[neighbors_key]

    cluster_mask = np.asarray(cluster_idx, dtype=bool)
    member_rows = np.flatnonzero(cluster_mask)

    # Columns with any stored/positive entry in the cluster's rows are neighbors
    if sparse.issparse(conn):
        # CSR guarantees row indexing works for every sparse input format
        neighbor_cols = np.unique(conn.tocsr()[member_rows, :].nonzero()[1])
    else:
        dense_rows = np.asarray(conn)[member_rows, :]
        neighbor_cols = np.flatnonzero((dense_rows > 0).any(axis=0))

    # Work on a copy so the caller's mask is left untouched
    expanded_idx = cluster_mask.copy()
    expanded_idx[neighbor_cols] = True

    n_original = cluster_mask.sum()
    n_expanded = expanded_idx.sum()
    n_neighbors = n_expanded - n_original
    print(f"  Including {n_neighbors} neighboring cells ({n_original} cluster + {n_neighbors} neighbors = {n_expanded} total)")

    # Within the expanded selection, a cell is a neighbor exactly when it was
    # not part of the original cluster mask.
    is_neighbor = ~cluster_mask[expanded_idx]

    return expanded_idx, is_neighbor


def _compute_child_lr(parent_final_lr, base_lr):
    """
    Derive a child's starting learning rate from its parent's final one.

    The base-10 exponent of the parent's final learning rate is halved:
    a parent that finished at 1e-8 (exponent -8) yields a child that
    starts at 1e-4 (exponent -4).

    ``base_lr`` is returned unchanged when ``parent_final_lr`` is None
    or not strictly positive (its logarithm would be undefined).
    """
    if parent_final_lr is None:
        return base_lr
    if parent_final_lr <= 0:
        return base_lr

    halved_exponent = math.log10(parent_final_lr) / 2
    return 10 ** halved_exponent


def _get_parent_params(
    adata: AnnData,
    cluster: str,
    parent_cluster: Optional[str]
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Look up the fitted parameters of a parent cluster for initialization.

    Parameters
    ----------
    adata : AnnData
        Annotated data object; the parent's results are read from
        ``adata.varp['W_<parent>']``, ``adata.var['I_<parent>']`` and,
        when present, ``adata.var['gamma_<parent>']``.
    cluster : str
        Current cluster being fitted (not used by the lookup itself)
    parent_cluster : str, optional
        Name of the parent cluster whose parameters are requested

    Returns
    -------
    Tuple of (W, I, gamma). All three are None when there is no parent or
    the parent's W matrix is missing; gamma alone is None when the parent
    has no refitted gamma column. Returned arrays are copies.
    """
    nothing = (None, None, None)

    if parent_cluster is None:
        return nothing

    weights_name = f'W_{parent_cluster}'
    if weights_name not in adata.varp:
        print(f"  Warning: Parent W matrix '{weights_name}' not found, using default initialization")
        return nothing

    parent_W = adata.varp[weights_name].copy()

    # var columns cover every gene; restrict them to the genes used for fitting
    used = get_genes_used(adata)
    parent_I = adata.var[f'I_{parent_cluster}'].values[used].copy()

    # A gamma column only exists if the parent was fitted with refit_gamma
    gamma_name = f'gamma_{parent_cluster}'
    if gamma_name not in adata.var.columns:
        return parent_W, parent_I, None

    parent_gamma = adata.var[gamma_name].values[used].copy()
    return parent_W, parent_I, parent_gamma


def _select_training_cells(adata, obs_key, cluster, include_neighbors, neighbors_key):
    """Return ``(idx, is_neighbor)`` selecting the cells used to fit ``cluster``."""
    # The pseudo-cluster 'all' uses every cell and never gets neighbors added
    if isinstance(cluster, str) and cluster == 'all':
        return np.ones(adata.n_obs, dtype=bool), None

    idx = np.asarray(adata.obs[obs_key].values == cluster)
    if not include_neighbors:
        return idx, None
    return _get_cluster_with_neighbors(adata, idx, neighbors_key)


def fit_interactions(
    adata: AnnData,
    cluster_key: str,
    spliced_key: str = 'Ms',
    velocity_key: str = 'velocity_S',
    degradation_key: str = 'gamma',
    w_threshold: float = 1e-5,
    w_scaffold: Optional[np.ndarray] = None,
    scaffold_regularization: float = 1.0,
    reconstruction_regularization: float = 1.0,
    bias_regularization: float = 1.0,
    bias_bias: float = 0.0,
    only_TFs: bool = False,
    infer_I: bool = False,
    refit_gamma: bool = False,
    pre_initialize_W: bool = False,
    n_epochs: int = 1000,
    criterion: str = 'L2',
    batch_size: int = 64,
    device: str = 'cpu',
    skip_all: bool = False,
    learning_rate: float = 0.1,
    use_scheduler: bool = False,
    scheduler_kws: Optional[Dict] = None,
    use_plateau_scheduler: bool = False,
    plateau_patience: int = 50,
    plateau_factor: float = 0.5,
    plateau_min_lr: float = 1e-6,
    get_plots: bool = False,
    hierarchical_pretrain: bool = False,
    hierarchy_keys: Optional[List[str]] = None,
    hierarchy_mappings: Optional[List[Dict[str, str]]] = None,
    drop_last: bool = True,
    balanced_sampling: bool = False,
    normalize_regularization: bool = False,
    include_neighbors: bool = False,
    neighbors_key: str = 'connectivities',
    neighbor_fraction: float = 0.0,
    hierarchical_scaling: bool = False,
    copy: bool = False
) -> Optional[AnnData]:
    """
    Infer gene regulatory network interaction matrices.

    Fits interaction matrix W and bias vector I for each cluster by solving:
        velocity = W * sigmoid(expression) - gamma * expression + I

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted sigmoid parameters
    cluster_key : str
        Key in adata.obs containing cluster annotations
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts
    velocity_key : str, optional (default: 'velocity_S')
        Key in adata.layers for RNA velocity
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates
    w_threshold : float, optional (default: 1e-5)
        Threshold for pruning small interaction weights
    w_scaffold : np.ndarray, optional
        Binary scaffold matrix constraining network topology
    scaffold_regularization : float, optional (default: 1.0)
        Regularization strength for scaffold constraint
    reconstruction_regularization : float, optional (default: 1.0)
        Regularization strength for reconstruction loss
    bias_regularization : float, optional (default: 1.0)
        Regularization strength for bias vector
    bias_bias : float, optional (default: 0.0)
        Additional bias term to encourage bias values
        (e.g., negative bias_bias encourages more positive biases)
    only_TFs : bool, optional (default: False)
        If True, use masked linear layer (requires w_scaffold)
    infer_I : bool, optional (default: False)
        If True, infer bias vector I in least squares
    refit_gamma : bool, optional (default: False)
        If True, refit degradation rates during optimization
    pre_initialize_W : bool, optional (default: False)
        If True, initialize W with least squares solution
    n_epochs : int, optional (default: 1000)
        Number of training epochs
    criterion : str, optional (default: 'L2')
        Loss function: 'L1', 'L2', or 'MSE'
    batch_size : int, optional (default: 64)
        Batch size for training
    device : str, optional (default: 'cpu')
        Device for computation: 'cpu' or 'cuda'
    skip_all : bool, optional (default: False)
        If True, skip fitting on all cells combined. Only consulted when
        hierarchical_pretrain=False; hierarchical pretraining always starts
        from the 'all' level.
    learning_rate : float, optional (default: 0.1)
        Initial learning rate for training
    use_scheduler : bool, optional (default: False)
        If True, use StepLR learning rate scheduler
    scheduler_kws : dict, optional
        Keyword arguments for StepLR scheduler
    use_plateau_scheduler : bool, optional (default: False)
        If True, use ReduceLROnPlateau scheduler that decreases learning rate
        when the loss plateaus. This overrides use_scheduler.
    plateau_patience : int, optional (default: 50)
        Number of epochs with no improvement after which learning rate
        will be reduced
    plateau_factor : float, optional (default: 0.5)
        Factor by which the learning rate will be reduced (new_lr = lr * factor)
    plateau_min_lr : float, optional (default: 1e-6)
        Minimum learning rate for plateau scheduler
    get_plots : bool, optional (default: False)
        If True, show training plots
    hierarchical_pretrain : bool, optional (default: False)
        If True, enable hierarchical pretraining. First trains on all cells,
        then uses those parameters to initialize cluster-specific training.
        If hierarchy_keys is provided, trains through multiple levels.
    hierarchy_keys : list of str, optional
        List of obs keys from coarse to fine clustering
        (e.g., ['lineage', 'cell_type']). Only used if
        hierarchical_pretrain=True. If None, uses simple two-level hierarchy:
        'all' -> cluster_key.
    hierarchy_mappings : list of dict, optional
        List of mappings between consecutive hierarchy levels. Each mapping is
        {fine_cluster: coarse_cluster}. Must have len(hierarchy_keys) - 1
        elements. Example:
        [{'T_cell': 'immune', 'B_cell': 'immune', 'Fibroblast': 'stromal'}]
    drop_last : bool, optional (default: True)
        If True, drop the last incomplete batch to ensure consistent batch
        sizes. This reduces gradient variance from small tail-end batches.
    balanced_sampling : bool, optional (default: False)
        If True, use weighted sampling to balance cluster representation when
        training on multiple clusters (e.g., when training 'all' during
        hierarchical pretraining). Requires hierarchical_pretrain=True.
    normalize_regularization : bool, optional (default: False)
        If True, normalize scaffold and bias regularization losses by batch
        size. This keeps regularization balanced with reconstruction loss when
        batch sizes vary. Alternative to drop_last for handling batch
        inconsistency.
    include_neighbors : bool, optional (default: False)
        If True, include neighboring cells (from any cluster) when training
        cluster-specific models. Neighbors are determined from the
        connectivity matrix. Only applies to non-'all' clusters.
    neighbors_key : str, optional (default: 'connectivities')
        Key in adata.obsp containing the cell-cell connectivity matrix.
        If not found, neighbors will be computed using scanpy.
    neighbor_fraction : float, optional (default: 0.0)
        Fraction of each training batch that should come from neighboring
        cells (cells not in the cluster but connected via the neighbor graph).
        Only applies when include_neighbors=True. Value must be in [0.0, 1.0).
        Example: 0.2 means 20% of each batch are neighbors, 80% cluster cells.
    hierarchical_scaling : bool, optional (default: False)
        If True and hierarchical_pretrain=True, use half epochs for
        pretraining levels (all levels except the finest) and adjust initial
        learning rate based on parent's final learning rate. Child levels
        start with LR exponent = parent_final_lr_exponent / 2
        (e.g., parent ends at 1e-8, child starts at 1e-4).
    copy : bool, optional (default: False)
        If True, return a copy instead of modifying in-place

    Returns
    -------
    AnnData or None
        Returns adata if copy=True, otherwise None.
        Adds to adata:
        - adata.varp[f'W_{cluster}']: interaction matrix for each cluster
        - adata.var[f'I_{cluster}']: bias vector for each cluster
        - adata.var[f'gamma_{cluster}']: refitted gamma if refit_gamma=True
        - adata.uns['scHopfield']['models'][cluster]: trained models
          if w_scaffold is provided
        - adata.uns['scHopfield']['hierarchy_levels']: (obs_key, clusters)
          per level if hierarchical_pretrain=True
        - adata.uns['scHopfield'][f'final_lr_{cluster}']: last learning rate
          if hierarchical_scaling=True and w_scaffold is provided
    """
    adata = adata.copy() if copy else adata

    # Store keys for downstream functions
    if 'scHopfield' not in adata.uns:
        adata.uns['scHopfield'] = {}
    adata.uns['scHopfield']['cluster_key'] = cluster_key
    adata.uns['scHopfield']['spliced_key'] = spliced_key
    adata.uns['scHopfield']['velocity_key'] = velocity_key
    adata.uns['scHopfield']['degradation_key'] = degradation_key

    # Get gene indices
    genes = get_genes_used(adata)

    # Get data matrices
    x = to_numpy(get_matrix(adata, spliced_key, genes=genes))
    v = to_numpy(get_matrix(adata, velocity_key, genes=genes))
    g = adata.var[degradation_key].values[genes].astype(x.dtype)
    sig = get_matrix(adata, 'sigmoid', genes=genes)

    # Initialize storage for models if using scaffold
    if w_scaffold is not None:
        if 'models' not in adata.uns['scHopfield']:
            adata.uns['scHopfield']['models'] = {}

    # Settings that are identical for every cluster fit, collected once so
    # the hierarchical and flat code paths cannot drift apart.
    shared_fit_kwargs = dict(
        w_threshold=w_threshold,
        w_scaffold=w_scaffold,
        scaffold_regularization=scaffold_regularization,
        reconstruction_regularization=reconstruction_regularization,
        bias_regularization=bias_regularization,
        bias_bias=bias_bias,
        only_TFs=only_TFs,
        infer_I=infer_I,
        refit_gamma=refit_gamma,
        pre_initialize_W=pre_initialize_W,
        criterion=criterion,
        batch_size=batch_size,
        device=device,
        use_scheduler=use_scheduler,
        scheduler_kws=scheduler_kws,
        get_plots=get_plots,
        use_plateau_scheduler=use_plateau_scheduler,
        plateau_patience=plateau_patience,
        plateau_factor=plateau_factor,
        plateau_min_lr=plateau_min_lr,
        drop_last=drop_last,
        normalize_regularization=normalize_regularization,
        neighbor_fraction=neighbor_fraction,
    )

    if hierarchical_pretrain:
        # Build hierarchy levels
        levels = _build_hierarchy_levels(adata, cluster_key, hierarchy_keys, hierarchy_mappings)
        adata.uns['scHopfield']['hierarchy_levels'] = [(lvl[0], lvl[1]) for lvl in levels]

        for level_idx, (level_key, clusters, parent_mapping) in enumerate(levels):
            # Every level except the last one is a pretraining level
            is_final_level = (level_idx == len(levels) - 1)

            print(f"\n{'='*60}")
            print(f"=== Training Level {level_idx}: {level_key} ({len(clusters)} clusters) ===")
            print(f"{'='*60}")

            for cluster in clusters:
                # Get parent parameters if not first level
                parent_cluster = parent_mapping.get(cluster) if parent_mapping else None
                parent_W, parent_I, parent_gamma = _get_parent_params(adata, cluster, parent_cluster)

                if parent_cluster:
                    print(f"\nInferring W and I for '{cluster}' (initialized from '{parent_cluster}')")
                else:
                    print(f"\nInferring W and I for '{cluster}'")

                idx, is_neighbor = _select_training_cells(
                    adata, level_key, cluster, include_neighbors, neighbors_key
                )

                # Use parent gamma if available, otherwise use default
                cluster_g = parent_gamma if parent_gamma is not None else g

                # Compute sample weights for balanced sampling (only for 'all' cluster)
                cluster_weights = None
                if balanced_sampling and cluster == 'all':
                    cluster_weights = _compute_cluster_weights(adata, idx, cluster_key)

                # Compute epochs for this level
                level_epochs = n_epochs
                if hierarchical_scaling and not is_final_level:
                    level_epochs = max(1, n_epochs // 2)
                    print(f"  Using {level_epochs} epochs (half for pretraining)")

                # Compute learning rate from parent's final LR
                level_lr = learning_rate
                if hierarchical_scaling and parent_cluster is not None:
                    parent_final_lr_key = f'final_lr_{parent_cluster}'
                    if parent_final_lr_key in adata.uns['scHopfield']:
                        parent_final_lr = adata.uns['scHopfield'][parent_final_lr_key]
                        level_lr = _compute_child_lr(parent_final_lr, learning_rate)
                        print(f"  Starting LR: {level_lr:.2e} (from parent final {parent_final_lr:.2e})")

                # Fit interactions for this cluster
                _fit_interactions_for_cluster(
                    adata=adata,
                    cluster=cluster,
                    x=x[idx, :],
                    v=v[idx, :],
                    sig=sig[idx, :],
                    g=cluster_g,
                    n_epochs=level_epochs,
                    learning_rate=level_lr,
                    parent_W=parent_W,
                    parent_I=parent_I,
                    sample_weights=cluster_weights,
                    is_neighbor=is_neighbor,
                    **shared_fit_kwargs,
                )

                # Store final learning rate if using scaffold (for hierarchical_scaling)
                if hierarchical_scaling and w_scaffold is not None:
                    model = adata.uns['scHopfield']['models'].get(cluster)
                    if model is not None and hasattr(model, 'lr_history') and model.lr_history:
                        final_lr = model.lr_history[-1]
                        adata.uns['scHopfield'][f'final_lr_{cluster}'] = final_lr

    else:
        # Original non-hierarchical behavior.
        # Keep the labels in a plain list: np.append(labels, 'all') would
        # coerce numeric labels to strings, after which `obs == cluster`
        # matches no cell at all.
        clusters = list(adata.obs[cluster_key].unique())
        if not skip_all:
            clusters.append('all')

        for cluster in clusters:
            print(f"Inferring interaction matrix W and bias vector I for cluster {cluster}")

            idx, is_neighbor = _select_training_cells(
                adata, cluster_key, cluster, include_neighbors, neighbors_key
            )

            # Fit interactions for this cluster
            _fit_interactions_for_cluster(
                adata=adata,
                cluster=cluster,
                x=x[idx, :],
                v=v[idx, :],
                sig=sig[idx, :],
                g=g,
                n_epochs=n_epochs,
                learning_rate=learning_rate,
                parent_W=None,
                parent_I=None,
                sample_weights=None,
                is_neighbor=is_neighbor,
                **shared_fit_kwargs,
            )

    return adata if copy else None


def _resolve_torch_device(device):
    """Map a device request to an available ``torch.device``, falling back to CPU."""
    requested = str(device)
    # Accept indexed devices such as "cuda:1", not only the bare "cuda"
    if requested.startswith("cuda") and torch.cuda.is_available():
        return torch.device(requested)
    if requested == "mps":
        # torch.backends.mps does not exist on older torch builds
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
    return torch.device("cpu")


def _fit_interactions_for_cluster(
    adata: AnnData,
    cluster: str,
    x: np.ndarray,
    v: np.ndarray,
    sig: np.ndarray,
    g: np.ndarray,
    w_threshold: float,
    w_scaffold: Optional[np.ndarray],
    scaffold_regularization: float,
    reconstruction_regularization: float,
    bias_regularization: float,
    bias_bias: float,
    only_TFs: bool,
    infer_I: bool,
    refit_gamma: bool,
    pre_initialize_W: bool,
    n_epochs: int,
    criterion: str,
    batch_size: int,
    device: str,
    learning_rate: float,
    use_scheduler: bool,
    scheduler_kws: Optional[Dict],
    use_plateau_scheduler: bool,
    plateau_patience: int,
    plateau_factor: float,
    plateau_min_lr: float,
    get_plots: bool,
    parent_W: Optional[np.ndarray] = None,
    parent_I: Optional[np.ndarray] = None,
    sample_weights: Optional[np.ndarray] = None,
    drop_last: bool = True,
    normalize_regularization: bool = False,
    is_neighbor: Optional[np.ndarray] = None,
    neighbor_fraction: float = 0.0,
):
    """
    Fit interaction matrix W and bias I for a single cluster.

    This is adapted from Landscape._fit_interactions_for_group.
    Modifies adata in-place: writes ``adata.varp['W_<cluster>']``,
    ``adata.var['I_<cluster>']``, ``adata.var['gamma_<cluster>']`` (only if
    ``refit_gamma``) and, when a scaffold is given,
    ``adata.uns['scHopfield']['models'][cluster]``.

    Without a scaffold the result is the least-squares solution of
    ``v + g * x = W @ sig (+ I)``. With a scaffold a ScaffoldOptimizer is
    trained, initialized from the parent parameters if provided, otherwise
    from the least-squares solution when ``pre_initialize_W`` is True.

    Raises
    ------
    RuntimeError
        If no scaffold is given and the least-squares solve fails, since
        there is then no W to store.
    """
    if scheduler_kws is None:
        scheduler_kws = {}

    device = _resolve_torch_device(device)

    W = None
    bias_vector = None

    # Parent parameters are only an *initialization* for scaffold training.
    # Without a scaffold nothing is trained, so using them would just copy
    # the parent's W into this cluster; fall through to least squares instead.
    if parent_W is not None and w_scaffold is not None:
        W = parent_W.copy()
        bias_vector = parent_I.copy() if parent_I is not None else None
        print("  Using parent parameters as initialization")
    # Otherwise use least squares initialization
    elif (w_scaffold is None) or pre_initialize_W:
        # A trailing column of ones lets the solve estimate the bias as well
        rhs = np.hstack((sig, np.ones((sig.shape[0], 1), dtype=x.dtype))) if infer_I else sig
        try:
            WI = np.linalg.lstsq(rhs, v + g[None, :] * x, rcond=1e-5)[0]
        except (np.linalg.LinAlgError, ValueError) as err:
            if w_scaffold is None:
                raise RuntimeError(
                    f"Least-squares fit failed for cluster '{cluster}' and no "
                    f"scaffold was provided to fall back on"
                ) from err
            # Best effort: scaffold training can still start from its defaults
            print(f"  Warning: least-squares initialization failed ({err}), using default initialization")
        else:
            W = WI[:-1, :].T if infer_I else WI.T
            bias_vector = WI[-1, :] if infer_I else -np.clip(WI, a_min=None, a_max=0).sum(axis=0)

    # Use ScaffoldOptimizer if scaffold provided
    if w_scaffold is not None:
        model = ScaffoldOptimizer(
            g, w_scaffold, device, refit_gamma,
            scaffold_regularization=scaffold_regularization,
            reconstruction_regularization=reconstruction_regularization,
            bias_regularization=bias_regularization,
            bias_bias=bias_bias,
            use_masked_linear=only_TFs,
            pre_initialized_W=W,
            pre_initialized_I=bias_vector,
            normalize_regularization=normalize_regularization
        )

        train_loader = _create_train_loader(
            sig, v, x, device, batch_size,
            sample_weights=sample_weights,
            drop_last=drop_last,
            is_neighbor=is_neighbor,
            neighbor_fraction=neighbor_fraction,
        )

        # Set up scheduler - plateau scheduler takes precedence
        if use_plateau_scheduler:
            scheduler_fn = None  # Will use built-in plateau scheduler
            scheduler_kwargs = {}
        elif use_scheduler:
            scheduler_fn = torch.optim.lr_scheduler.StepLR
            scheduler_kwargs = scheduler_kws if scheduler_kws else {"step_size": 100, "gamma": 0.4}
        else:
            scheduler_fn = None
            scheduler_kwargs = {}

        model.train_model(
            train_loader, n_epochs,
            learning_rate=learning_rate,
            criterion=criterion,
            scheduler_fn=scheduler_fn,
            scheduler_kwargs=scheduler_kwargs,
            use_plateau_scheduler=use_plateau_scheduler,
            plateau_patience=plateau_patience,
            plateau_factor=plateau_factor,
            plateau_min_lr=plateau_min_lr,
            get_plots=get_plots
        )

        W = model.W.weight.detach().cpu().numpy()
        bias_vector = model.I.detach().cpu().numpy()
        # NOTE(review): exp() implies the model stores gamma in log space - confirm
        g = np.exp(model.gamma.detach().cpu().numpy())

        # setdefault keeps direct calls working when the caller has not
        # prepared the storage dictionaries beforehand
        models = adata.uns.setdefault('scHopfield', {}).setdefault('models', {})
        models[cluster] = model.cpu()

    # Threshold and store
    W[np.abs(W) < w_threshold] = 0
    bias_vector[np.abs(bias_vector) < w_threshold] = 0

    # Store interaction matrix in varp
    adata.varp[f'W_{cluster}'] = W

    # Store bias vector in var (one column per cluster); unused genes stay 0
    adata.var[f'I_{cluster}'] = 0.0
    gene_indices = get_genes_used(adata)
    adata.var.iloc[gene_indices, adata.var.columns.get_loc(f'I_{cluster}')] = bias_vector

    # Store refitted gamma in var if applicable
    if refit_gamma:
        adata.var[f'gamma_{cluster}'] = 0.0
        adata.var.iloc[gene_indices, adata.var.columns.get_loc(f'gamma_{cluster}')] = g


def _compute_cluster_weights(adata, indices, cluster_key):
    """
    Compute per-sample weights to balance cluster representation.

    Each sample gets weight = 1 / (frequency of its cluster).
    This ensures smaller clusters are over-sampled.

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    indices : np.ndarray
        Boolean array indicating which cells to include
    cluster_key : str
        Key in adata.obs containing cluster annotations

    Returns
    -------
    torch.Tensor
        Float64 per-sample weights (same length as sum of indices),
        normalized so that they sum to the number of samples
    """
    cluster_labels = np.asarray(adata.obs.loc[indices, cluster_key].values)
    _, inverse_indices, counts = np.unique(
        cluster_labels, return_inverse=True, return_counts=True
    )

    # Weight = 1 / frequency
    weights_per_cluster = 1.0 / counts
    sample_weights = weights_per_cluster[inverse_indices]

    # Normalize so weights sum to len(sample_weights)
    sample_weights = sample_weights / sample_weights.sum() * len(sample_weights)

    return torch.as_tensor(sample_weights, dtype=torch.float64)


class ControlledNeighborBatchSampler:
    """
    Batch sampler that ensures each batch has a controlled composition
    of cluster cells vs neighbor cells.

    Cells are drawn with replacement using NumPy's global random state.
    If one of the two pools is empty, the whole batch is drawn from the
    other pool so that the batch size stays constant.

    Parameters
    ----------
    n_samples : int
        Total number of samples in dataset
    batch_size : int
        Size of each batch
    is_neighbor : np.ndarray
        Boolean array where True = neighbor, False = cluster cell
    neighbor_fraction : float
        Target fraction of neighbors in each batch (0.0 to 1.0)
    drop_last : bool
        Whether to drop the last incomplete batch

    Raises
    ------
    ValueError
        If ``neighbor_fraction`` is outside [0.0, 1.0].
    """

    def __init__(self, n_samples, batch_size, is_neighbor, neighbor_fraction, drop_last=True):
        if not 0.0 <= neighbor_fraction <= 1.0:
            raise ValueError(
                f"neighbor_fraction must be in [0.0, 1.0], got {neighbor_fraction}"
            )

        self.batch_size = batch_size
        self.neighbor_fraction = neighbor_fraction
        self.drop_last = drop_last

        # Coerce to a boolean array so that `~` is a logical negation
        neighbor_mask = np.asarray(is_neighbor, dtype=bool)
        self.cluster_indices = np.where(~neighbor_mask)[0]
        self.neighbor_indices = np.where(neighbor_mask)[0]

        # Number of batches per epoch (floor when dropping the last batch,
        # ceiling otherwise)
        if drop_last:
            self.n_batches = n_samples // batch_size
        else:
            self.n_batches = (n_samples + batch_size - 1) // batch_size

    def __iter__(self):
        n_neighbor_per_batch = int(self.batch_size * self.neighbor_fraction)
        n_cluster_per_batch = self.batch_size - n_neighbor_per_batch

        # With an empty pool, fill the batch entirely from the other pool
        # instead of silently yielding smaller batches.
        if len(self.neighbor_indices) == 0:
            n_cluster_per_batch, n_neighbor_per_batch = self.batch_size, 0
        elif len(self.cluster_indices) == 0:
            n_cluster_per_batch, n_neighbor_per_batch = 0, self.batch_size

        for _ in range(self.n_batches):
            batch = []

            # Sample cluster cells (with replacement if needed)
            if n_cluster_per_batch > 0 and len(self.cluster_indices) > 0:
                cluster_samples = np.random.choice(
                    self.cluster_indices,
                    size=n_cluster_per_batch,
                    replace=True
                )
                batch.extend(cluster_samples.tolist())

            # Sample neighbor cells (with replacement if needed)
            if n_neighbor_per_batch > 0 and len(self.neighbor_indices) > 0:
                neighbor_samples = np.random.choice(
                    self.neighbor_indices,
                    size=n_neighbor_per_batch,
                    replace=True
                )
                batch.extend(neighbor_samples.tolist())

            # Shuffle the batch so neighbors aren't always at the end
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.n_batches


def _create_train_loader(sig, v, x, device, batch_size=64, sample_weights=None, drop_last=True,
                         is_neighbor=None, neighbor_fraction=0.0):
    """
    Helper to create PyTorch DataLoader.

    Parameters
    ----------
    sig : np.ndarray
        Sigmoid-transformed expression matrix
    v : np.ndarray
        Velocity matrix
    x : np.ndarray
        Expression matrix
    device : str or torch.device
        Device for computation (passed through to CustomDataset)
    batch_size : int
        Batch size for training; clamped to the dataset size
    sample_weights : np.ndarray, optional
        Per-sample weights for WeightedRandomSampler. If provided, weighted
        sampling is used instead of uniform sampling. Ignored when controlled
        neighbor sampling is active.
    drop_last : bool, optional (default: True)
        Drop the last incomplete batch to ensure consistent batch sizes
    is_neighbor : np.ndarray, optional
        Boolean array indicating which samples are neighbors (vs cluster
        cells). Required if neighbor_fraction > 0.
    neighbor_fraction : float, optional (default: 0.0)
        Target fraction of neighbors per batch. If > 0, uses
        ControlledNeighborBatchSampler.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    dataset = CustomDataset(sig, v, x, device)

    # Clamp batch_size to dataset size
    effective_batch_size = min(batch_size, len(dataset))

    # Controlled neighbor sampling takes precedence
    if neighbor_fraction > 0 and is_neighbor is not None:
        batch_sampler = ControlledNeighborBatchSampler(
            n_samples=len(dataset),
            batch_size=effective_batch_size,
            is_neighbor=is_neighbor,
            neighbor_fraction=neighbor_fraction,
            drop_last=drop_last
        )
        return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)

    if sample_weights is not None:
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(dataset),
            replacement=True
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=effective_batch_size,
            sampler=sampler,
            drop_last=drop_last
        )

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=effective_batch_size,
        shuffle=True,
        drop_last=drop_last
    )