"""Inference of gene regulatory network interactions."""
import math
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler
from typing import Optional, List, Dict, Tuple
from anndata import AnnData
from .optimizer import ScaffoldOptimizer
from .datasets import CustomDataset
from .._utils.io import get_matrix, to_numpy, get_genes_used
def _build_hierarchy_levels(
    adata: AnnData,
    cluster_key: str,
    hierarchy_keys: Optional[List[str]],
    hierarchy_mappings: Optional[List[Dict[str, str]]]
) -> List[Tuple[str, List[str], Optional[Dict[str, str]]]]:
    """
    Build the list of training levels from coarse to fine.

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    cluster_key : str
        Key for the finest level clustering (used if hierarchy_keys is None)
    hierarchy_keys : list of str, optional
        List of obs keys from coarse to fine (e.g., ['lineage', 'cell_type'])
    hierarchy_mappings : list of dict, optional
        List of {fine_cluster: coarse_cluster} mappings between consecutive
        levels. ``None`` is treated as an empty list, which is what a
        single-key hierarchy requires.

    Returns
    -------
    list of tuple
        ``(obs_key, clusters_at_this_level, mapping_to_parent)`` per level.
        The first level is always 'all' with mapping None.

    Raises
    ------
    ValueError
        If ``hierarchy_keys`` is empty, if the number of mappings is not
        ``len(hierarchy_keys) - 1``, or if a fine cluster has no entry in
        its mapping.
    """
    levels = [('all', ['all'], None)]

    if hierarchy_keys is None:
        # Simple case: all → cluster_key
        clusters = list(adata.obs[cluster_key].unique())
        levels.append((cluster_key, clusters, {c: 'all' for c in clusters}))
        return levels

    if len(hierarchy_keys) == 0:
        raise ValueError("hierarchy_keys must contain at least one obs key")

    # A missing mapping list means "no mappings"; that is valid when there is
    # only one hierarchy key and nothing has to be linked.
    mappings = [] if hierarchy_mappings is None else hierarchy_mappings
    n_expected = len(hierarchy_keys) - 1
    if len(mappings) != n_expected:
        raise ValueError(
            f"hierarchy_mappings must have {n_expected} elements "
            f"(one fewer than hierarchy_keys), got {len(mappings)}"
        )

    # First level after 'all' - the coarsest clustering
    coarse_key = hierarchy_keys[0]
    coarse_clusters = list(adata.obs[coarse_key].unique())
    levels.append((coarse_key, coarse_clusters, {c: 'all' for c in coarse_clusters}))

    # Subsequent levels: mappings[i] links hierarchy_keys[i + 1] to hierarchy_keys[i]
    for i, (fine_key, fine_to_coarse) in enumerate(zip(hierarchy_keys[1:], mappings)):
        fine_clusters = list(adata.obs[fine_key].unique())
        for fc in fine_clusters:
            if fc not in fine_to_coarse:
                raise ValueError(
                    f"Fine cluster '{fc}' from '{fine_key}' not found in hierarchy_mappings[{i}]"
                )
        levels.append((fine_key, fine_clusters, fine_to_coarse))

    return levels
def _get_cluster_with_neighbors(adata, cluster_idx, neighbors_key='connectivities'):
    """
    Expand cluster selection to include neighboring cells.

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    cluster_idx : np.ndarray
        Boolean array of cells in the cluster (not modified)
    neighbors_key : str
        Key in adata.obsp for connectivity matrix. If it is missing, the
        graph is computed with scanpy; this is only possible for
        'connectivities' or a key of the form '<name>_connectivities',
        because those are the only slots scanpy writes to.

    Returns
    -------
    tuple
        (expanded_idx, is_neighbor) where:
        - expanded_idx: Boolean array including cluster cells and neighbors
        - is_neighbor: Boolean array (same size as expanded_idx.sum()) where
          True indicates the cell is a neighbor, False indicates cluster member

    Raises
    ------
    KeyError
        If ``neighbors_key`` is absent and scanpy cannot produce a graph
        under that name.
    """
    from scipy import sparse

    # Get or compute connectivity matrix
    if neighbors_key not in adata.obsp:
        suffix = '_connectivities'
        if neighbors_key == 'connectivities':
            key_added = None
        elif neighbors_key.endswith(suffix) and len(neighbors_key) > len(suffix):
            key_added = neighbors_key[:-len(suffix)]
        else:
            # Fail before touching adata: computing the default graph would
            # not create this key and could overwrite an existing graph.
            raise KeyError(
                f"'{neighbors_key}' not found in adata.obsp and it cannot be computed "
                f"automatically; use 'connectivities' or '<name>_connectivities'"
            )
        # Imported lazily so scanpy is only required when a graph must be built.
        import scanpy as sc
        print(f" Computing neighbors ('{neighbors_key}' not found in adata.obsp)")
        sc.pp.neighbors(adata, key_added=key_added)
    conn = adata.obsp[neighbors_key]

    cluster_idx = np.asarray(cluster_idx, dtype=bool)
    cluster_rows = np.flatnonzero(cluster_idx)

    # Columns with a non-zero entry in any cluster row are the neighbors
    if sparse.issparse(conn):
        neighbor_cols = conn[cluster_rows, :].nonzero()[1]
    else:
        neighbor_cols = np.flatnonzero((np.asarray(conn)[cluster_rows, :] > 0).any(axis=0))

    # Create expanded boolean mask (on a copy, the caller's mask is untouched)
    expanded_idx = cluster_idx.copy()
    expanded_idx[neighbor_cols] = True

    n_original = cluster_idx.sum()
    n_expanded = expanded_idx.sum()
    n_neighbors = n_expanded - n_original
    print(f" Including {n_neighbors} neighboring cells ({n_original} cluster + {n_neighbors} neighbors = {n_expanded} total)")

    # Within the expanded selection, a neighbor is any cell that was not in
    # the original cluster mask.
    is_neighbor = ~cluster_idx[expanded_idx]
    return expanded_idx, is_neighbor
def _compute_child_lr(parent_final_lr, base_lr):
    """
    Compute child learning rate from parent's final LR.

    Uses exponent halving: if parent ended at 1e-8 (exp=-8),
    child starts at 1e-4 (exp=-4, half of -8).

    Returns base_lr if parent_final_lr is None, non-positive, or not a
    finite number (NaN/inf), so an invalid parent value never becomes the
    child's learning rate.
    """
    if (
        parent_final_lr is None
        or not math.isfinite(parent_final_lr)
        or parent_final_lr <= 0
    ):
        return base_lr
    parent_exp = math.log10(parent_final_lr)  # e.g., -8 for 1e-8
    return 10 ** (parent_exp / 2)  # e.g., 1e-4
def _get_parent_params(
    adata: AnnData,
    cluster: str,
    parent_cluster: Optional[str]
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Retrieve W, I, gamma from parent cluster for initialization.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with parent cluster already fitted
    cluster : str
        Current cluster being fitted (currently unused by this function)
    parent_cluster : str, optional
        Parent cluster name to retrieve parameters from

    Returns
    -------
    Tuple of (W, I, gamma)
        ``(None, None, None)`` if there is no parent or its W matrix is
        missing. ``I`` and ``gamma`` are individually ``None`` when their
        column is absent from ``adata.var``. Returned arrays are copies.
    """
    if parent_cluster is None:
        return None, None, None

    W_key = f'W_{parent_cluster}'
    if W_key not in adata.varp:
        print(f" Warning: Parent W matrix '{W_key}' not found, using default initialization")
        return None, None, None
    W = adata.varp[W_key].copy()

    gene_indices = get_genes_used(adata)

    # The bias column is checked like W and gamma, so a missing column
    # degrades to "no bias initialization" rather than raising KeyError.
    I_key = f'I_{parent_cluster}'
    if I_key in adata.var.columns:
        bias_vector = adata.var[I_key].values[gene_indices].copy()
    else:
        print(f" Warning: Parent bias vector '{I_key}' not found, initializing from W only")
        bias_vector = None

    # Check for refitted gamma
    gamma_key = f'gamma_{parent_cluster}'
    if gamma_key in adata.var.columns:
        gamma = adata.var[gamma_key].values[gene_indices].copy()
    else:
        gamma = None

    return W, bias_vector, gamma
[docs]
def fit_interactions(
    adata: AnnData,
    cluster_key: str,
    spliced_key: str = 'Ms',
    velocity_key: str = 'velocity_S',
    degradation_key: str = 'gamma',
    w_threshold: float = 1e-5,
    w_scaffold: Optional[np.ndarray] = None,
    scaffold_regularization: float = 1.0,
    reconstruction_regularization: float = 1.0,
    bias_regularization: float = 1.0,
    bias_bias: float = 0.0,
    only_TFs: bool = False,
    infer_I: bool = False,
    refit_gamma: bool = False,
    pre_initialize_W: bool = False,
    n_epochs: int = 1000,
    criterion: str = 'L2',
    batch_size: int = 64,
    device: str = 'cpu',
    skip_all: bool = False,
    learning_rate: float = 0.1,
    use_scheduler: bool = False,
    scheduler_kws: Optional[Dict] = None,
    use_plateau_scheduler: bool = False,
    plateau_patience: int = 50,
    plateau_factor: float = 0.5,
    plateau_min_lr: float = 1e-6,
    get_plots: bool = False,
    hierarchical_pretrain: bool = False,
    hierarchy_keys: Optional[List[str]] = None,
    hierarchy_mappings: Optional[List[Dict[str, str]]] = None,
    drop_last: bool = True,
    balanced_sampling: bool = False,
    normalize_regularization: bool = False,
    include_neighbors: bool = False,
    neighbors_key: str = 'connectivities',
    neighbor_fraction: float = 0.0,
    hierarchical_scaling: bool = False,
    copy: bool = False
) -> Optional[AnnData]:
    """
    Infer gene regulatory network interaction matrices.

    Fits interaction matrix W and bias vector I for each cluster by
    solving: velocity = W * sigmoid(expression) - gamma * expression + I

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted sigmoid parameters
    cluster_key : str
        Key in adata.obs containing cluster annotations
    spliced_key : str, optional (default: 'Ms')
        Key in adata.layers for spliced counts
    velocity_key : str, optional (default: 'velocity_S')
        Key in adata.layers for RNA velocity
    degradation_key : str, optional (default: 'gamma')
        Key in adata.var for degradation rates
    w_threshold : float, optional (default: 1e-5)
        Threshold for pruning small interaction weights
    w_scaffold : np.ndarray, optional
        Binary scaffold matrix constraining network topology
    scaffold_regularization : float, optional (default: 1.0)
        Regularization strength for scaffold constraint
    reconstruction_regularization : float, optional (default: 1.0)
        Regularization strength for reconstruction loss
    bias_regularization : float, optional (default: 1.0)
        Regularization strength for bias vector
    bias_bias : float, optional (default: 0.0)
        Additional bias term to encourage bias values (e.g., negative
        bias_bias encourages more positive biases)
    only_TFs : bool, optional (default: False)
        If True, use masked linear layer (requires w_scaffold)
    infer_I : bool, optional (default: False)
        If True, infer bias vector I in least squares
    refit_gamma : bool, optional (default: False)
        If True, refit degradation rates during optimization
    pre_initialize_W : bool, optional (default: False)
        If True, initialize W with least squares solution
    n_epochs : int, optional (default: 1000)
        Number of training epochs
    criterion : str, optional (default: 'L2')
        Loss function: 'L1', 'L2', or 'MSE'
    batch_size : int, optional (default: 64)
        Batch size for training
    device : str, optional (default: 'cpu')
        Device for computation: 'cpu' or 'cuda'
    skip_all : bool, optional (default: False)
        If True, skip fitting on all cells combined. Only consulted when
        hierarchical_pretrain=False; the hierarchical path always starts
        from the 'all' level.
    learning_rate : float, optional (default: 0.1)
        Initial learning rate for training
    use_scheduler : bool, optional (default: False)
        If True, use StepLR learning rate scheduler
    scheduler_kws : dict, optional
        Keyword arguments for StepLR scheduler
    use_plateau_scheduler : bool, optional (default: False)
        If True, use ReduceLROnPlateau scheduler that decreases learning rate
        when the loss plateaus. This overrides use_scheduler.
    plateau_patience : int, optional (default: 50)
        Number of epochs with no improvement after which learning rate will
        be reduced
    plateau_factor : float, optional (default: 0.5)
        Factor by which the learning rate will be reduced (new_lr = lr * factor)
    plateau_min_lr : float, optional (default: 1e-6)
        Minimum learning rate for plateau scheduler
    get_plots : bool, optional (default: False)
        If True, show training plots
    hierarchical_pretrain : bool, optional (default: False)
        If True, enable hierarchical pretraining. First trains on all cells,
        then uses those parameters to initialize cluster-specific training.
        If hierarchy_keys is provided, trains through multiple levels.
    hierarchy_keys : list of str, optional
        List of obs keys from coarse to fine clustering
        (e.g., ['lineage', 'cell_type']). Only used if
        hierarchical_pretrain=True. If None, uses simple two-level
        hierarchy: 'all' → cluster_key.
    hierarchy_mappings : list of dict, optional
        List of mappings between consecutive hierarchy levels. Each mapping is
        {fine_cluster: coarse_cluster}. Must have len(hierarchy_keys) - 1
        elements.
        Example: [{'T_cell': 'immune', 'B_cell': 'immune', 'Fibroblast': 'stromal'}]
    drop_last : bool, optional (default: True)
        If True, drop the last incomplete batch to ensure consistent batch
        sizes. This reduces gradient variance from small tail-end batches.
    balanced_sampling : bool, optional (default: False)
        If True, use weighted sampling to balance cluster representation when
        training on multiple clusters (e.g., when training 'all' during
        hierarchical pretraining). Requires hierarchical_pretrain=True.
    normalize_regularization : bool, optional (default: False)
        If True, normalize scaffold and bias regularization losses by batch
        size. This keeps regularization balanced with reconstruction loss
        when batch sizes vary. Alternative to drop_last for handling batch
        inconsistency.
    include_neighbors : bool, optional (default: False)
        If True, include neighboring cells (from any cluster) when training
        cluster-specific models. Neighbors are determined from the
        connectivity matrix. Only applies to non-'all' clusters.
    neighbors_key : str, optional (default: 'connectivities')
        Key in adata.obsp containing the cell-cell connectivity matrix.
        If not found, neighbors will be computed using scanpy.
    neighbor_fraction : float, optional (default: 0.0)
        Fraction of each training batch that should come from neighboring
        cells (cells not in the cluster but connected via the neighbor
        graph). Only applies when include_neighbors=True. Value must be in
        [0.0, 1.0).
        Example: 0.2 means 20% of each batch are neighbors, 80% cluster cells.
    hierarchical_scaling : bool, optional (default: False)
        If True and hierarchical_pretrain=True, use half epochs for
        pretraining levels (all levels except the finest) and adjust initial
        learning rate based on parent's final learning rate. Child levels
        start with LR exponent = parent_final_lr_exponent / 2 (e.g., parent
        ends at 1e-8, child starts at 1e-4).
    copy : bool, optional (default: False)
        If True, return a copy instead of modifying in-place

    Returns
    -------
    AnnData or None
        Returns adata if copy=True, otherwise None.
        Adds to adata:
        - adata.varp[f'W_{cluster}']: interaction matrix for each cluster
        - adata.var[f'I_{cluster}']: bias vector for each cluster
        - adata.var[f'gamma_{cluster}']: refitted gamma if refit_gamma=True
        - adata.uns['scHopfield']['models'][cluster]: trained models if
          w_scaffold is provided

    Raises
    ------
    ValueError
        If include_neighbors=True and neighbor_fraction is outside [0.0, 1.0),
        or if the hierarchy specification is inconsistent.
    """
    # Enforce the documented range before doing any work; values >= 1 would
    # leave no room for the cluster's own cells in a batch.
    if include_neighbors and not 0.0 <= neighbor_fraction < 1.0:
        raise ValueError(
            f"neighbor_fraction must be in [0.0, 1.0), got {neighbor_fraction}"
        )

    adata = adata.copy() if copy else adata

    # Store keys for downstream functions
    if 'scHopfield' not in adata.uns:
        adata.uns['scHopfield'] = {}
    adata.uns['scHopfield']['cluster_key'] = cluster_key
    adata.uns['scHopfield']['spliced_key'] = spliced_key
    adata.uns['scHopfield']['velocity_key'] = velocity_key
    adata.uns['scHopfield']['degradation_key'] = degradation_key

    # Get gene indices
    genes = get_genes_used(adata)

    # Get data matrices
    x = to_numpy(get_matrix(adata, spliced_key, genes=genes))
    v = to_numpy(get_matrix(adata, velocity_key, genes=genes))
    g = adata.var[degradation_key].values[genes].astype(x.dtype)
    sig = get_matrix(adata, 'sigmoid', genes=genes)

    # Initialize storage for models if using scaffold
    if w_scaffold is not None:
        if 'models' not in adata.uns['scHopfield']:
            adata.uns['scHopfield']['models'] = {}

    # Settings that are identical for every cluster fit, whichever path is taken.
    shared_kwargs = dict(
        w_threshold=w_threshold,
        w_scaffold=w_scaffold,
        scaffold_regularization=scaffold_regularization,
        reconstruction_regularization=reconstruction_regularization,
        bias_regularization=bias_regularization,
        bias_bias=bias_bias,
        only_TFs=only_TFs,
        infer_I=infer_I,
        refit_gamma=refit_gamma,
        pre_initialize_W=pre_initialize_W,
        criterion=criterion,
        batch_size=batch_size,
        device=device,
        use_scheduler=use_scheduler,
        scheduler_kws=scheduler_kws,
        get_plots=get_plots,
        use_plateau_scheduler=use_plateau_scheduler,
        plateau_patience=plateau_patience,
        plateau_factor=plateau_factor,
        plateau_min_lr=plateau_min_lr,
        drop_last=drop_last,
        normalize_regularization=normalize_regularization,
        neighbor_fraction=neighbor_fraction,
    )

    if hierarchical_pretrain:
        # Build hierarchy levels
        levels = _build_hierarchy_levels(adata, cluster_key, hierarchy_keys, hierarchy_mappings)
        adata.uns['scHopfield']['hierarchy_levels'] = [(lvl[0], lvl[1]) for lvl in levels]
        last_level_idx = len(levels) - 1

        for level_idx, (level_key, clusters, parent_mapping) in enumerate(levels):
            # Every level except the last one is a pretraining level
            is_final_level = level_idx == last_level_idx
            print(f"\n{'='*60}")
            print(f"=== Training Level {level_idx}: {level_key} ({len(clusters)} clusters) ===")
            print(f"{'='*60}")

            for cluster in clusters:
                # Get parent parameters if not first level
                parent_cluster = parent_mapping.get(cluster) if parent_mapping else None
                parent_W, parent_I, parent_gamma = _get_parent_params(adata, cluster, parent_cluster)
                if parent_cluster is not None:
                    print(f"\nInferring W and I for '{cluster}' (initialized from '{parent_cluster}')")
                else:
                    print(f"\nInferring W and I for '{cluster}'")

                idx, is_neighbor = _select_cluster_cells(
                    adata, level_key, cluster, include_neighbors, neighbors_key
                )

                # Use parent gamma if available, otherwise use default
                cluster_g = parent_gamma if parent_gamma is not None else g

                # Compute sample weights for balanced sampling (only for 'all' cluster)
                cluster_weights = None
                if balanced_sampling and cluster == 'all':
                    cluster_weights = _compute_cluster_weights(adata, idx, cluster_key)

                # Compute epochs for this level
                level_epochs = n_epochs
                if hierarchical_scaling and not is_final_level:
                    level_epochs = max(1, n_epochs // 2)
                    print(f" Using {level_epochs} epochs (half for pretraining)")

                # Compute learning rate from parent's final LR
                level_lr = learning_rate
                if hierarchical_scaling and parent_cluster is not None:
                    parent_final_lr_key = f'final_lr_{parent_cluster}'
                    if parent_final_lr_key in adata.uns['scHopfield']:
                        parent_final_lr = adata.uns['scHopfield'][parent_final_lr_key]
                        level_lr = _compute_child_lr(parent_final_lr, learning_rate)
                        print(f" Starting LR: {level_lr:.2e} (from parent final {parent_final_lr:.2e})")

                # Fit interactions for this cluster
                _fit_interactions_for_cluster(
                    adata=adata,
                    cluster=cluster,
                    x=x[idx, :],
                    v=v[idx, :],
                    sig=sig[idx, :],
                    g=cluster_g,
                    n_epochs=level_epochs,
                    learning_rate=level_lr,
                    parent_W=parent_W,
                    parent_I=parent_I,
                    sample_weights=cluster_weights,
                    is_neighbor=is_neighbor,
                    **shared_kwargs,
                )

                # Store final learning rate if using scaffold (for hierarchical_scaling)
                if hierarchical_scaling and w_scaffold is not None:
                    model = adata.uns['scHopfield']['models'].get(cluster)
                    if model is not None and hasattr(model, 'lr_history') and model.lr_history:
                        final_lr = model.lr_history[-1]
                        adata.uns['scHopfield'][f'final_lr_{cluster}'] = final_lr
    else:
        # Original non-hierarchical behavior.
        # A plain list keeps each label's original type; np.append would
        # coerce non-string labels to str and the equality mask below would
        # then match no cells.
        clusters = list(adata.obs[cluster_key].unique())
        if not skip_all:
            clusters.append('all')

        for cluster in clusters:
            print(f"Inferring interaction matrix W and bias vector I for cluster {cluster}")

            idx, is_neighbor = _select_cluster_cells(
                adata, cluster_key, cluster, include_neighbors, neighbors_key
            )

            # Fit interactions for this cluster
            _fit_interactions_for_cluster(
                adata=adata,
                cluster=cluster,
                x=x[idx, :],
                v=v[idx, :],
                sig=sig[idx, :],
                g=g,
                n_epochs=n_epochs,
                learning_rate=learning_rate,
                parent_W=None,
                parent_I=None,
                sample_weights=None,
                is_neighbor=is_neighbor,
                **shared_kwargs,
            )

    return adata if copy else None


def _select_cluster_cells(adata, obs_key, cluster, include_neighbors, neighbors_key):
    """Return (cell mask, is_neighbor) for one cluster; 'all' selects every cell."""
    if cluster == 'all':
        return np.ones(adata.n_obs, dtype=bool), None
    idx = adata.obs[obs_key].values == cluster
    if include_neighbors:
        return _get_cluster_with_neighbors(adata, idx, neighbors_key)
    return idx, None
def _fit_interactions_for_cluster(
    adata: AnnData,
    cluster: str,
    x: np.ndarray,
    v: np.ndarray,
    sig: np.ndarray,
    g: np.ndarray,
    w_threshold: float,
    w_scaffold: Optional[np.ndarray],
    scaffold_regularization: float,
    reconstruction_regularization: float,
    bias_regularization: float,
    bias_bias: float,
    only_TFs: bool,
    infer_I: bool,
    refit_gamma: bool,
    pre_initialize_W: bool,
    n_epochs: int,
    criterion: str,
    batch_size: int,
    device: str,
    learning_rate: float,
    use_scheduler: bool,
    scheduler_kws: Optional[Dict],
    use_plateau_scheduler: bool,
    plateau_patience: int,
    plateau_factor: float,
    plateau_min_lr: float,
    get_plots: bool,
    parent_W: Optional[np.ndarray] = None,
    parent_I: Optional[np.ndarray] = None,
    sample_weights: Optional[np.ndarray] = None,
    drop_last: bool = True,
    normalize_regularization: bool = False,
    is_neighbor: Optional[np.ndarray] = None,
    neighbor_fraction: float = 0.0,
):
    """
    Fit interaction matrix W and bias I for a single cluster.

    This is adapted from Landscape._fit_interactions_for_group.
    Modifies adata in-place.

    Without a scaffold the result is the least-squares solution for the
    cells passed in. With a scaffold, a ScaffoldOptimizer is trained,
    starting from the parent parameters if given, else from the
    least-squares solution if ``pre_initialize_W`` is set, else from the
    optimizer's own defaults.

    Parameters
    ----------
    adata : AnnData
        Object that receives ``varp['W_<cluster>']``, ``var['I_<cluster>']``,
        optionally ``var['gamma_<cluster>']`` and, with a scaffold,
        ``uns['scHopfield']['models'][cluster]``.
    cluster : str
        Label used to build the storage keys.
    x, v, sig : array-like
        Expression, velocity and sigmoid-transformed expression for the
        cells of this cluster (rows are cells).
    g : np.ndarray
        Degradation rate per gene.
    parent_W, parent_I : np.ndarray, optional
        Parent parameters used to initialize the optimizer. They are only
        used when a scaffold is provided, since otherwise nothing is trained.
    sample_weights, drop_last, is_neighbor, neighbor_fraction
        Forwarded to ``_create_train_loader``.

    The remaining parameters have the same meaning as in ``fit_interactions``.

    Raises
    ------
    RuntimeError
        If no scaffold is given and the least-squares fit fails, so there is
        nothing to store.
    """
    if scheduler_kws is None:
        scheduler_kws = {}

    # Any device that is not an available "cuda"/"mps" falls back to CPU.
    if device == "cuda" and torch.cuda.is_available():
        device = torch.device("cuda")
    elif device == "mps" and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    W = None
    bias_vector = None
    lstsq_error = None

    # Parent parameters are only a starting point for gradient training, so
    # they are used only when a scaffold (and thus an optimizer) exists.
    # Without a scaffold the cluster must get its own least-squares fit;
    # copying the parent would store the parent's W unchanged.
    if parent_W is not None and w_scaffold is not None:
        W = parent_W.copy()
        bias_vector = parent_I.copy() if parent_I is not None else None
        print(" Using parent parameters as initialization")
    elif (w_scaffold is None) or pre_initialize_W:
        design = np.hstack((sig, np.ones((sig.shape[0], 1), dtype=x.dtype))) if infer_I else sig
        try:
            WI = np.linalg.lstsq(design, v + g[None, :] * x, rcond=1e-5)[0]
        except Exception as err:
            # Best effort: with a scaffold the optimizer can still start from
            # its defaults, so report the failure and continue.
            lstsq_error = err
            print(f" Warning: least squares fit failed ({err})")
        else:
            if infer_I:
                # The last row of the solution belongs to the constant column
                W = WI[:-1, :].T
                bias_vector = WI[-1, :]
            else:
                W = WI.T
                bias_vector = -np.clip(WI, a_min=None, a_max=0).sum(axis=0)

    # Use ScaffoldOptimizer if scaffold provided
    if w_scaffold is not None:
        model = ScaffoldOptimizer(
            g, w_scaffold, device, refit_gamma,
            scaffold_regularization=scaffold_regularization,
            reconstruction_regularization=reconstruction_regularization,
            bias_regularization=bias_regularization,
            bias_bias=bias_bias,
            use_masked_linear=only_TFs,
            pre_initialized_W=W,
            pre_initialized_I=bias_vector,
            normalize_regularization=normalize_regularization
        )
        train_loader = _create_train_loader(
            sig, v, x, device, batch_size,
            sample_weights=sample_weights,
            drop_last=drop_last,
            is_neighbor=is_neighbor,
            neighbor_fraction=neighbor_fraction,
        )

        # Set up scheduler - plateau scheduler takes precedence
        if use_scheduler and not use_plateau_scheduler:
            scheduler_fn = torch.optim.lr_scheduler.StepLR
            scheduler_kwargs = scheduler_kws or {"step_size": 100, "gamma": 0.4}
        else:
            # Either no scheduler, or the plateau scheduler built into the model
            scheduler_fn = None
            scheduler_kwargs = {}

        model.train_model(
            train_loader, n_epochs,
            learning_rate=learning_rate,
            criterion=criterion,
            scheduler_fn=scheduler_fn,
            scheduler_kwargs=scheduler_kwargs,
            use_plateau_scheduler=use_plateau_scheduler,
            plateau_patience=plateau_patience,
            plateau_factor=plateau_factor,
            plateau_min_lr=plateau_min_lr,
            get_plots=get_plots
        )
        W = model.W.weight.detach().cpu().numpy()
        bias_vector = model.I.detach().cpu().numpy()
        # model.gamma is exponentiated here before being stored as the rate
        g = np.exp(model.gamma.detach().cpu().numpy())
        adata.uns['scHopfield']['models'][cluster] = model.cpu()
    elif W is None:
        raise RuntimeError(
            f"Least squares fit failed for cluster '{cluster}' and no scaffold was "
            f"provided, so no interaction matrix could be estimated"
        ) from lstsq_error

    # Threshold and store
    W[np.abs(W) < w_threshold] = 0
    bias_vector[np.abs(bias_vector) < w_threshold] = 0

    # Store interaction matrix in varp
    adata.varp[f'W_{cluster}'] = W

    # Store bias vector in var (one column per cluster); genes that were not
    # used keep the 0.0 fill value.
    adata.var[f'I_{cluster}'] = 0.0
    gene_indices = get_genes_used(adata)
    adata.var.iloc[gene_indices, adata.var.columns.get_loc(f'I_{cluster}')] = bias_vector

    # Store refitted gamma in var if applicable
    if refit_gamma:
        adata.var[f'gamma_{cluster}'] = 0.0
        adata.var.iloc[gene_indices, adata.var.columns.get_loc(f'gamma_{cluster}')] = g
def _compute_cluster_weights(adata, indices, cluster_key):
    """
    Compute per-sample weights to balance cluster representation.

    Each sample gets weight = 1 / (frequency of its cluster).
    This ensures smaller clusters are over-sampled.

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    indices : np.ndarray
        Boolean array indicating which cells to include
    cluster_key : str
        Key in adata.obs containing cluster annotations

    Returns
    -------
    torch.Tensor
        float64 tensor of per-sample weights, one per selected cell,
        normalized so that the weights sum to the number of selected cells.
    """
    cluster_labels = adata.obs.loc[indices, cluster_key].values
    _, inverse_indices, counts = np.unique(
        cluster_labels, return_inverse=True, return_counts=True
    )

    # Weight = 1 / frequency, looked up per sample through the inverse index
    sample_weights = (1.0 / counts)[inverse_indices]

    # Normalize so weights sum to len(sample_weights)
    sample_weights = sample_weights / sample_weights.sum() * len(sample_weights)

    return torch.as_tensor(sample_weights, dtype=torch.float64)
class ControlledNeighborBatchSampler:
    """
    Batch sampler that ensures each batch has a controlled composition
    of cluster cells vs neighbor cells.

    Every batch is drawn with replacement and has exactly ``batch_size``
    entries. If one of the two pools is empty, the whole batch is drawn
    from the other pool so the batch size stays constant.

    Parameters
    ----------
    n_samples : int
        Total number of samples in dataset
    batch_size : int
        Size of each batch
    is_neighbor : np.ndarray
        Array where truthy = neighbor, falsy = cluster cell
    neighbor_fraction : float
        Target fraction of neighbors in each batch (0.0 to 1.0)
    drop_last : bool
        Whether to drop the last incomplete batch (only affects the number
        of batches per epoch)

    Raises
    ------
    ValueError
        If ``neighbor_fraction`` is outside [0.0, 1.0].
    """

    def __init__(self, n_samples, batch_size, is_neighbor, neighbor_fraction, drop_last=True):
        if not 0.0 <= neighbor_fraction <= 1.0:
            raise ValueError(
                f"neighbor_fraction must be in [0.0, 1.0], got {neighbor_fraction}"
            )
        self.batch_size = batch_size
        self.neighbor_fraction = neighbor_fraction
        self.drop_last = drop_last

        # Coerce to bool first: `~` on an integer mask is a bitwise NOT and
        # would put every sample into the cluster pool.
        neighbor_mask = np.asarray(is_neighbor, dtype=bool)
        self.cluster_indices = np.flatnonzero(~neighbor_mask)
        self.neighbor_indices = np.flatnonzero(neighbor_mask)

        # Compute number of batches
        if drop_last:
            self.n_batches = n_samples // batch_size
        else:
            self.n_batches = (n_samples + batch_size - 1) // batch_size

    def _batch_composition(self):
        """Return (n_cluster, n_neighbor) draws per batch given the pool sizes."""
        has_cluster = len(self.cluster_indices) > 0
        has_neighbors = len(self.neighbor_indices) > 0
        if has_cluster and has_neighbors:
            n_neighbor = int(self.batch_size * self.neighbor_fraction)
            return self.batch_size - n_neighbor, n_neighbor
        if has_cluster:
            return self.batch_size, 0
        if has_neighbors:
            return 0, self.batch_size
        return 0, 0

    def __iter__(self):
        n_cluster_per_batch, n_neighbor_per_batch = self._batch_composition()

        for _ in range(self.n_batches):
            batch = []
            if n_cluster_per_batch > 0:
                cluster_samples = np.random.choice(
                    self.cluster_indices, size=n_cluster_per_batch, replace=True
                )
                batch.extend(cluster_samples.tolist())
            if n_neighbor_per_batch > 0:
                neighbor_samples = np.random.choice(
                    self.neighbor_indices, size=n_neighbor_per_batch, replace=True
                )
                batch.extend(neighbor_samples.tolist())

            # Shuffle the batch so neighbors aren't always at the end
            np.random.shuffle(batch)
            yield batch

    def __len__(self):
        return self.n_batches
def _create_train_loader(sig, v, x, device, batch_size=64, sample_weights=None, drop_last=True,
                         is_neighbor=None, neighbor_fraction=0.0):
    """
    Build the PyTorch DataLoader used to train one cluster.

    The sampling strategy is chosen in this order:

    1. ``neighbor_fraction > 0`` and ``is_neighbor`` given: batches come from
       a ControlledNeighborBatchSampler.
    2. ``sample_weights`` given: a WeightedRandomSampler draws
       ``len(dataset)`` samples with replacement.
    3. Otherwise: plain shuffling.

    Parameters
    ----------
    sig : np.ndarray
        Sigmoid-transformed expression matrix
    v : np.ndarray
        Velocity matrix
    x : np.ndarray
        Expression matrix
    device
        Passed through to CustomDataset
    batch_size : int
        Requested batch size; capped at the dataset length
    sample_weights : optional
        Per-sample weights for WeightedRandomSampler
    drop_last : bool, optional (default: True)
        Drop the last incomplete batch to ensure consistent batch sizes
    is_neighbor : np.ndarray, optional
        Boolean array indicating which samples are neighbors (vs cluster
        cells). Required for neighbor_fraction to have any effect.
    neighbor_fraction : float, optional (default: 0.0)
        Target fraction of neighbors per batch

    Returns
    -------
    torch.utils.data.DataLoader
    """
    dataset = CustomDataset(sig, v, x, device)
    n_samples = len(dataset)

    # Never ask for more samples per batch than the dataset holds
    effective_batch_size = min(batch_size, n_samples)

    if neighbor_fraction > 0 and is_neighbor is not None:
        # Controlled neighbor sampling takes precedence over everything else
        loader_kwargs = {
            "batch_sampler": ControlledNeighborBatchSampler(
                n_samples=n_samples,
                batch_size=effective_batch_size,
                is_neighbor=is_neighbor,
                neighbor_fraction=neighbor_fraction,
                drop_last=drop_last,
            ),
        }
    else:
        loader_kwargs = {
            "batch_size": effective_batch_size,
            "drop_last": drop_last,
        }
        if sample_weights is not None:
            loader_kwargs["sampler"] = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=n_samples,
                replacement=True,
            )
        else:
            loader_kwargs["shuffle"] = True

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)