"""Simulation utilities for gene regulatory network dynamics."""
import numpy as np
from typing import Optional, List, Dict, Union
from anndata import AnnData
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from .solver import create_solver, ODESolver
from .._utils.io import get_matrix, to_numpy, get_genes_used
from ._utils import _parse_perturb_genes, _update_scHopfield_uns
# torchdiffeq method names that can be used on the GPU path
_TORCHDIFFEQ_METHODS = frozenset([
'euler', 'rk4', 'midpoint', 'dopri5', 'dopri8',
'bosh3', 'adaptive_heun', 'fehlberg2',
])
def _run_jobs(func, items, n_jobs):
"""Run func(item) for each item. Sequential when n_jobs=1, threaded otherwise.
Falls back to sequential if parallel execution raises any exception."""
if n_jobs == 1:
return [func(item) for item in items]
try:
return Parallel(n_jobs=n_jobs, prefer='threads')(
delayed(func)(item) for item in items
)
except Exception:
return [func(item) for item in items]
def simulate_trajectory(
    adata: AnnData,
    cluster: str,
    cell_idx: Union[int, List[int]],
    t_span: np.ndarray,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    method: str = 'euler',
    x_max_percentile: float = 99.0,
    n_jobs: int = 1,
    verbose: bool = False
) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Simulate trajectory from one or more cells' initial states.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions
    cluster : str
        Cluster name
    cell_idx : int or list of int
        Index or list of indices of cells to use as initial conditions.
        Returns a single trajectory array for a scalar, or a list for a list.
    t_span : np.ndarray
        Time points for simulation
    spliced_key : str, optional
        Key for expression data
    degradation_key : str, optional
        Key for degradation rates
    method : str, optional (default: 'euler')
        Integration method:
        - 'euler': Simple Euler method with clipping (stable, recommended)
        - 'odeint': scipy.integrate.odeint
        - 'RK45', 'RK23', etc.: scipy.integrate.solve_ivp methods
    x_max_percentile : float, optional (default: 99.0)
        Percentile of expression to use as upper bound. Prevents divergence.
        Set to None to disable upper bound.
    n_jobs : int, optional (default: 1)
        Number of parallel jobs when cell_idx is a list. 1 = sequential,
        -1 = all cores. Uses threads; no effect for a single cell.
    verbose : bool, optional (default: False)
        Print simulation info

    Returns
    -------
    np.ndarray or list of np.ndarray
        Trajectory (len(t_span) × n_genes) for a scalar cell_idx,
        or a list of trajectories for a list input (empty list for an
        empty input list).
    """
    single = isinstance(cell_idx, (int, np.integer))
    indices = [int(cell_idx)] if single else list(cell_idx)

    genes = get_genes_used(adata)
    X_all = to_numpy(get_matrix(adata, spliced_key, genes=genes))
    solver = create_solver(
        adata, cluster, degradation_key,
        spliced_key=spliced_key,
        x_max_percentile=x_max_percentile
    )

    if verbose:
        print(f"Simulating {len(indices)} trajectory/ies in cluster '{cluster}'")
        print(f" Method: {method}")
        print(f" Time span: {t_span[0]:.2f} to {t_span[-1]:.2f} ({len(t_span)} points)")
        if solver.x_max is not None:
            print(f" Upper bound: {x_max_percentile}th percentile × 2")

    def _run_one(idx):
        # Clip negative expression to 0; np.maximum already returns a new
        # array, so the cached expression matrix is never mutated.
        x0 = np.maximum(X_all[idx], 0)
        return solver.solve(x0, t_span, method=method, clip_each_step=True)

    results = _run_jobs(_run_one, indices, n_jobs)

    # An empty cell_idx list produces no trajectories to summarize.
    if verbose and results:
        final_state = results[-1][-1]
        print(f" Final values range: [{final_state.min():.3f}, {final_state.max():.3f}]")
    return results[0] if single else results
def simulate_perturbation_ode(
    adata: AnnData,
    cluster: str,
    cell_idx: Union[int, List[int]],
    gene_perturbations: dict,
    t_span: np.ndarray,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    method: str = 'euler',
    x_max_percentile: float = 99.0,
    residual_gene_dynamics: bool = False,
    n_jobs: int = 1,
    verbose: bool = False
) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Integrate the network ODE from perturbed initial states.

    Each starting cell has its (non-negative) expression overwritten at the
    perturbed genes before integration. Unless ``residual_gene_dynamics`` is
    True, those genes are also registered on the solver as fixed, so they stay
    at their perturbed values for the whole run; otherwise they are free to
    evolve after the initial kick.
    For CellOracle-style perturbation simulation, use
    sch.dyn.simulate_perturbation instead.

    Parameters
    ----------
    adata : AnnData
        Annotated data object
    cluster : str
        Cluster whose fitted parameters build the solver.
    cell_idx : int or list of int
        Starting cell(s). A scalar yields one trajectory, a list yields a list.
    gene_perturbations : dict
        Gene name -> value applied at t=0, e.g. {"Gata1": 0.0} (knockout) or
        {"Gata1": 5.0} (overexpression). Values are validated as non-negative.
    t_span : np.ndarray
        Time points
    spliced_key : str, optional
        Expression data key
    degradation_key : str, optional
        Degradation rates key
    method : str, optional (default: 'euler')
        Integration method ('euler', 'odeint', 'RK45', etc.)
    x_max_percentile : float, optional (default: 99.0)
        Percentile for upper bound. Set to None to disable.
    residual_gene_dynamics : bool, optional (default: False)
        False: perturbed genes are held fixed. True: they evolve freely after
        the initial perturbation.
    n_jobs : int, optional (default: 1)
        Threaded jobs when cell_idx is a list (1 = sequential, -1 = all cores).
    verbose : bool, optional (default: False)
        Print simulation info

    Returns
    -------
    np.ndarray or list of np.ndarray
        One perturbed trajectory per starting cell (unwrapped for a scalar).
    """
    is_scalar = isinstance(cell_idx, (int, np.integer))
    if is_scalar:
        start_cells = [int(cell_idx)]
    else:
        start_cells = list(cell_idx)

    used = get_genes_used(adata)
    names = adata.var.index[used]
    expr = to_numpy(get_matrix(adata, spliced_key, genes=used))

    # Resolve gene names -> column positions once, shared by every cell.
    pert_idx, pert_val = _parse_perturb_genes(
        names, gene_perturbations, validate_non_negative=True
    )
    has_perturbations = len(pert_idx) > 0

    ode = create_solver(
        adata, cluster, degradation_key,
        spliced_key=spliced_key,
        x_max_percentile=x_max_percentile
    )
    if has_perturbations and not residual_gene_dynamics:
        ode.set_fixed_genes(pert_idx, pert_val)
    else:
        ode.set_fixed_genes(None, None)

    if verbose:
        print(f"Simulating perturbation for {len(start_cells)} cell(s) in cluster '{cluster}'")
        print(f" Perturbations: {gene_perturbations}")
        print(f" Perturbed genes: {'can evolve' if residual_gene_dynamics else 'held constant'}")
        print(f" Method: {method}")

    def _integrate(start):
        state = np.maximum(expr[start].copy(), 0)
        if has_perturbations:
            state[pert_idx] = pert_val
        return ode.solve(state, t_span, method=method, clip_each_step=True)

    trajectories = _run_jobs(_integrate, start_cells, n_jobs)
    return trajectories[0] if is_scalar else trajectories
def _simulate_cluster_gpu(
X_cluster: np.ndarray,
solver: 'ODESolver',
t_span: np.ndarray,
method: str,
device: str,
) -> np.ndarray:
"""
Integrate all cells in a cluster simultaneously on the GPU.
Implements the same Hopfield ODE as ODESolver but processes the entire
cluster as a single batched tensor operation instead of a cell-by-cell
Python loop. Uses torchdiffeq.odeint when available; falls back to a
native torch Euler loop otherwise (still GPU-batched).
Parameters
----------
X_cluster : np.ndarray
Initial expression states, shape (n_cells, n_genes).
solver : ODESolver
Configured solver carrying W, I, gamma, threshold, exponent,
x_min, x_max, and fixed-gene info.
t_span : np.ndarray
Time points, shape (n_steps,).
method : str
Integration method name recognised by torchdiffeq
('euler', 'rk4', 'dopri5', …). If torchdiffeq is unavailable,
'euler' is handled by a native torch loop; other methods raise a
warning and also fall back to the torch Euler loop.
device : torch.device
GPU (or CPU) device to run the computation on.
Returns
-------
np.ndarray
Final expression states, shape (n_cells, n_genes), on CPU.
"""
import torch
dtype = torch.float32
x_min_v = float(solver.x_min)
# ── Parameters → tensors ─────────────────────────────────────────────────
W_t = torch.tensor(solver.W, dtype=dtype, device=device)
I_t = torch.tensor(solver.I, dtype=dtype, device=device)
gamma_t = torch.tensor(solver.gamma, dtype=dtype, device=device)
threshold_t = torch.tensor(solver.threshold, dtype=dtype, device=device)
exponent_t = torch.tensor(solver.exponent, dtype=dtype, device=device)
x_max_t = (
torch.tensor(solver.x_max, dtype=dtype, device=device)
if solver.x_max is not None else None
)
# ── Initial states ────────────────────────────────────────────────────────
X0 = torch.tensor(np.maximum(X_cluster, x_min_v), dtype=dtype, device=device)
fixed_indices = solver.fixed_indices
fixed_values = solver.fixed_values
fixed_values_t = None
fixed_mask_t = None
if fixed_indices is not None and len(fixed_indices) > 0:
fixed_values_t = torch.tensor(fixed_values, dtype=dtype, device=device)
X0[:, fixed_indices] = fixed_values_t
fixed_mask_t = torch.zeros(solver.W.shape[0], dtype=torch.bool, device=device)
fixed_mask_t[fixed_indices] = True
# ── Batched Hill-sigmoid ODE ──────────────────────────────────────────────
# Hill sigmoid: x^n / (x^n + s^n) — matches sigmoid() in _utils/math.py
def dynamics(t, x):
# x: (n_cells, n_genes)
x_c = x.clamp(min=x_min_v)
if x_max_t is not None:
x_c = x_c.clamp(max=x_max_t)
x_pos = x_c.clamp(min=1e-12) # avoid 0^n for fractional n
xn = x_pos ** exponent_t # (n_cells, n_genes)
sn = threshold_t ** exponent_t # (n_genes,) — broadcast
sig = xn / (xn + sn) # Hill sigmoid
dxdt = sig @ W_t.T - gamma_t * x_c + I_t # (n_cells, n_genes)
# Soft lower boundary: don't push below x_min
dxdt = torch.where(x <= x_min_v, dxdt.clamp(min=0.0), dxdt)
# Soft upper boundary
if x_max_t is not None:
dxdt = torch.where(x >= x_max_t, dxdt.clamp(max=0.0), dxdt)
# Fixed genes: zero derivative so they stay constant
if fixed_mask_t is not None:
dxdt = dxdt.clone()
dxdt[:, fixed_mask_t] = 0.0
return dxdt
# ── Integration ───────────────────────────────────────────────────────────
with torch.no_grad():
try:
import torchdiffeq
_have_tde = True
except ImportError:
_have_tde = False
if _have_tde and method in _TORCHDIFFEQ_METHODS:
t_tensor = torch.tensor(t_span, dtype=dtype, device=device)
trajectory = torchdiffeq.odeint(dynamics, X0, t_tensor, method=method)
# trajectory: (n_steps, n_cells, n_genes) → take last time point
X_final = trajectory[-1]
else:
# Native torch Euler loop — no torchdiffeq dependency required.
if method != 'euler':
import warnings
warnings.warn(
f"torchdiffeq not available or method '{method}' is not in "
f"_TORCHDIFFEQ_METHODS; falling back to torch Euler on {device}.",
UserWarning,
)
x = X0.clone()
for i in range(1, len(t_span)):
dt_step = float(t_span[i] - t_span[i - 1])
x = x + dt_step * dynamics(None, x)
x = x.clamp(min=x_min_v)
if x_max_t is not None:
x = x.clamp(max=x_max_t)
if fixed_indices is not None and len(fixed_indices) > 0:
x[:, fixed_indices] = fixed_values_t
X_final = x
return X_final.cpu().numpy()
def _resolve_shift_device(device, method):
    """Map simulate_shift_ode's ``device`` argument to ``(use_gpu, torch_device)``."""
    # Explicit CPU never needs torch at all.
    if device is not None and str(device) == 'cpu':
        return False, None
    try:
        import torch
    except ImportError:
        if device is None:
            # Auto-detect without torch installed: only the CPU path exists.
            return False, None
        raise
    if device is None:
        if torch.cuda.is_available():
            torch_device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            torch_device = torch.device('mps')
        else:
            return False, None
        # Auto-detect only uses the accelerator for torch-compatible methods.
        return method in _TORCHDIFFEQ_METHODS, torch_device
    # Explicit device: accept any torch device spec ('cuda:1', torch.device, ...).
    torch_device = torch.device(device)
    if torch_device.type == 'cpu':
        return False, None
    if torch_device.type == 'cuda' and not torch.cuda.is_available():
        raise RuntimeError(f"device={device!r} requested but CUDA is not available.")
    if torch_device.type == 'mps' and not torch.backends.mps.is_available():
        raise RuntimeError(f"device={device!r} requested but MPS is not available.")
    return True, torch_device


def _is_out_of_memory(err):
    """Return True if ``err`` is an accelerator out-of-memory error (CUDA or MPS)."""
    import torch
    oom_type = getattr(torch.cuda, 'OutOfMemoryError', None)
    if oom_type is not None and isinstance(err, oom_type):
        return True
    # MPS (and older CUDA builds) report OOM as a plain RuntimeError.
    return 'out of memory' in str(err).lower()


def _empty_device_cache(torch_device):
    """Release cached allocator memory on ``torch_device`` (no-op for other backends)."""
    import torch
    if torch_device.type == 'cuda':
        torch.cuda.empty_cache()
    elif torch_device.type == 'mps':
        empty_cache = getattr(getattr(torch, 'mps', None), 'empty_cache', None)
        if empty_cache is not None:
            empty_cache()


def simulate_shift_ode(
    adata: 'AnnData',
    perturb_condition: Dict[str, float],
    cluster_key: str,
    dt: float = 5.0,
    n_steps: int = 100,
    spliced_key: str = 'Ms',
    degradation_key: str = 'gamma',
    method: str = 'euler',
    use_cluster_specific_GRN: bool = True,
    x_max_percentile: float = 99.0,
    residual_gene_dynamics: bool = False,
    n_jobs: int = -1,
    device: Optional[str] = None,
    verbose: bool = False
) -> 'AnnData':
    """
    Simulate dataset-wide trajectory shifts with gene perturbations using ODE integration.

    This function mimics the propagation-based `simulate_shift` but uses continuous
    ODE integration. It computes the final state of every cell after time `dt`
    under the perturbed conditions and stores the resulting shift (delta_X).

    When an accelerator (CUDA or Apple MPS) is selected, all cells in each
    cluster are integrated simultaneously as one batched tensor operation via
    `_simulate_cluster_gpu`. That helper uses `torchdiffeq.odeint` when
    installed and otherwise falls back to a native torch Euler loop. If the
    device runs out of memory, the affected cluster and all remaining clusters
    fall back to the CPU path with a RuntimeWarning. Results are always moved
    back to CPU before being stored.

    Parameters
    ----------
    adata : AnnData
        Annotated data object with fitted interactions
    perturb_condition : dict
        Dictionary mapping gene names to perturbation values (e.g., {"Gata1": 0.0}).
    cluster_key : str
        Key in adata.obs containing cluster assignments.
    dt : float, optional (default: 5.0)
        Total time duration to simulate the ODEs.
    n_steps : int, optional (default: 100)
        Number of time points for the ODE solver (np.linspace(0, dt, n_steps)).
    spliced_key : str, optional (default: 'Ms')
        Key for expression data.
    degradation_key : str, optional (default: 'gamma')
        Key for degradation rates.
    method : str, optional (default: 'euler')
        Integration method.
        GPU-compatible (via torchdiffeq or native torch):
            'euler', 'rk4', 'midpoint', 'dopri5', 'dopri8', 'bosh3',
            'adaptive_heun', 'fehlberg2'
        CPU-only (scipy):
            'odeint', 'RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA'
    use_cluster_specific_GRN : bool, optional (default: True)
        If True, uses cluster-specific solvers. If False, uses a global solver.
    x_max_percentile : float, optional (default: 99.0)
        Percentile for upper bound. Prevents divergence.
    residual_gene_dynamics : bool, optional (default: False)
        If False, perturbed genes are held fixed. If True, they evolve.
    n_jobs : int, optional (default: -1)
        Number of parallel jobs for the CPU cell loop.
        Ignored when the GPU path is active.
    device : str, torch.device or None, optional (default: None)
        Target device for batched integration.
        None   → auto-detect: use 'cuda', else 'mps', if available and
                 `method` is GPU-compatible; otherwise the CPU path.
        'cpu'  → always use the CPU path (scipy/joblib); torch is not needed.
        any other torch device spec ('cuda', 'cuda:1', 'mps', ...) → force
                 that device (raises RuntimeError if CUDA/MPS is unavailable).
    verbose : bool, optional (default: False)
        Print simulation progress.

    Returns
    -------
    AnnData
        A copy of the input with layers 'simulated_count', 'delta_X' and
        'simulated_velocity' added (full (n_obs, n_vars) float32 arrays; genes
        not used by the model are zero). All arrays are numpy (CPU).
    """
    import warnings

    # ── Resolve target device ─────────────────────────────────────────────────
    use_gpu, torch_device = _resolve_shift_device(device, method)
    if verbose:
        backend = f"GPU ({torch_device})" if use_gpu else "CPU"
        print(f"simulate_shift_ode: backend={backend}, method={method}")

    adata_out = adata.copy()

    # Identify used genes and initial states
    genes_mask = get_genes_used(adata_out)
    gene_names = adata_out.var_names[genes_mask]
    X_orig = to_numpy(get_matrix(adata_out, spliced_key, genes=genes_mask))
    X_sim = np.zeros_like(X_orig)
    V_sim = np.zeros_like(X_orig)

    # Time span for the ODE simulation
    t_span = np.linspace(0, dt, n_steps)

    # The perturbation is identical for every cluster: parse/validate it once.
    all_indices, all_values = _parse_perturb_genes(
        gene_names, perturb_condition, validate_non_negative=True
    )
    has_perturbation = len(all_indices) > 0
    if not residual_gene_dynamics and has_perturbation:
        fixed_indices, fixed_values = all_indices, all_values
    else:
        fixed_indices, fixed_values = None, None

    clusters = adata_out.obs[cluster_key].unique() if use_cluster_specific_GRN else [None]
    for cluster in clusters:
        # Compare against None explicitly so labels like 0 are not shown as global.
        is_global = cluster is None
        if verbose:
            print(f"Processing cluster: {'Global' if is_global else cluster}")

        if is_global:
            cell_indices = np.arange(adata_out.n_obs)
        else:
            cell_indices = np.where(adata_out.obs[cluster_key] == cluster)[0]
        if len(cell_indices) == 0:
            continue

        solver = create_solver(
            adata_out, cluster, degradation_key,
            spliced_key=spliced_key,
            x_max_percentile=x_max_percentile
        )
        solver.set_fixed_genes(fixed_indices, fixed_values)

        if use_gpu:
            # ── GPU path: integrate all cells in the cluster as one batch ────
            try:
                X_sim[cell_indices] = _simulate_cluster_gpu(
                    X_orig[cell_indices], solver, t_span, method, torch_device
                )
            except RuntimeError as err:
                if not _is_out_of_memory(err):
                    raise
                _empty_device_cache(torch_device)  # flush before falling back
                warnings.warn(
                    f"GPU OOM on cluster '{cluster}'; falling back to CPU for this cluster.",
                    RuntimeWarning,
                )
                use_gpu = False  # disable GPU for remaining clusters too
            else:
                _empty_device_cache(torch_device)  # release cache after each cluster

        if not use_gpu:
            # ── CPU path: cell-by-cell with joblib threads ───────────────────
            label = 'global' if is_global else cluster
            with tqdm(total=len(cell_indices), desc=f"Cells in {label}",
                      disable=not verbose) as pbar:
                def _simulate_cell(x0_row):
                    # np.maximum returns a new array, so the perturbation write
                    # below never touches X_orig.
                    x0 = np.maximum(x0_row, 0)
                    if has_perturbation:
                        x0[all_indices] = all_values
                    x_final = solver.solve(
                        x0, t_span, method=method, clip_each_step=True
                    )[-1]
                    pbar.update(1)  # progress reflects finished integrations
                    return x_final

                results = _run_jobs(_simulate_cell, list(X_orig[cell_indices]), n_jobs)
            X_sim[cell_indices] = np.array(results)

        # Velocity at final state (CPU numpy, vectorised over cells)
        V_sim[cell_indices] = solver.dynamics_batch(X_sim[cell_indices], 0.0)

    delta_X = X_sim - X_orig

    # Layers must span all of adata.var: scatter used-gene columns back.
    n_cells, n_all_genes = adata_out.shape

    def _to_full(values):
        full = np.zeros((n_cells, n_all_genes), dtype=np.float32)
        full[:, genes_mask] = values
        return full

    adata_out.layers['simulated_count'] = _to_full(X_sim)
    adata_out.layers['delta_X'] = _to_full(delta_X)
    adata_out.layers['simulated_velocity'] = _to_full(V_sim)

    # Update scHopfield metadata dict
    _update_scHopfield_uns(adata_out, perturb_condition=perturb_condition,
                           simulation_method='ODE', ode_dt=dt)
    return adata_out
def calculate_trajectory_flow(
    adata: AnnData,
    wt_trajectories: Dict[str, np.ndarray],
    perturbed_trajectories: Dict[str, np.ndarray],
    cluster_key: str = 'cell_type',
    basis: str = 'umap',
    time_point: int = -1,
    method: str = 'hopfield',
    n_neighbors: int = 30,
    n_jobs: int = 4,
    verbose: bool = True,
) -> np.ndarray:
    """
    Calculate perturbation flow from ODE trajectory simulation results.

    Takes one time point (default: the last) from per-cluster WT and perturbed
    trajectories, broadcasts it to every cell of that cluster, and projects the
    resulting change into embedding space. Clusters missing from
    `perturbed_trajectories`, or with no cells in `adata`, contribute zeros.

    Side effects on `adata`: writes layers['delta_X_ode'] (full n_obs × n_vars;
    genes not used by the model are zero), obsm['ode_perturbation_flow_{basis}']
    and uns['ode_perturbation_flow_params'].

    Parameters
    ----------
    adata : AnnData
        Annotated data with cell information
    wt_trajectories : dict
        Dictionary mapping cluster -> WT trajectory (n_time, n_genes)
    perturbed_trajectories : dict
        Dictionary mapping cluster -> perturbed trajectory (n_time, n_genes)
    cluster_key : str, optional (default: 'cell_type')
        Key for cluster labels
    basis : str, optional (default: 'umap')
        Embedding basis
    time_point : int, optional (default: -1)
        Which time point to use (-1 for final)
    method : str, optional (default: 'hopfield')
        Flow calculation method:
        - 'hopfield': project the difference between Hopfield velocities at
          the perturbed and WT states
        - 'difference': project the expression difference directly
        Any other value is treated as 'difference'.
    n_neighbors : int, optional (default: 30)
        Number of neighbors for projection
    n_jobs : int, optional (default: 4)
        Number of parallel jobs
    verbose : bool, optional (default: True)
        Print progress

    Returns
    -------
    np.ndarray
        Perturbation flow in embedding space, as returned by
        project_to_embedding (documented originally as (n_cells, 2)).
    """
    # Import here to avoid circular imports
    from ..tools.velocity import compute_velocity
    from ..tools.embedding import project_to_embedding

    genes = get_genes_used(adata)
    n_cells = adata.n_obs
    # Counts used genes correctly for both boolean masks and index arrays
    # (len() of a boolean mask would return n_vars instead).
    n_genes = len(adata.var_names[genes])

    delta_X = np.zeros((n_cells, n_genes), dtype=np.float32)
    X_wt_final = np.zeros((n_cells, n_genes), dtype=np.float32)
    X_pert_final = np.zeros((n_cells, n_genes), dtype=np.float32)

    # Clusters with both trajectories and at least one cell -> their rows
    cluster_cells = {}
    for cluster in wt_trajectories:
        if cluster not in perturbed_trajectories:
            continue
        cell_indices = np.flatnonzero(np.asarray(adata.obs[cluster_key] == cluster))
        if cell_indices.size == 0:
            continue
        cluster_cells[cluster] = cell_indices

        wt_final = np.asarray(wt_trajectories[cluster])[time_point]
        pert_final = np.asarray(perturbed_trajectories[cluster])[time_point]
        # Broadcast the single cluster-level state to every cell in the cluster
        X_wt_final[cell_indices] = wt_final
        X_pert_final[cell_indices] = pert_final
        delta_X[cell_indices] = pert_final - wt_final

    # AnnData layers must be (n_obs, n_vars): scatter used-gene columns back.
    delta_X_full = np.zeros((n_cells, adata.n_vars), dtype=np.float32)
    delta_X_full[:, genes] = delta_X
    adata.layers['delta_X_ode'] = delta_X_full

    if method == 'hopfield':
        if verbose:
            print("Computing Hopfield velocities...")
        delta_velocity = np.zeros((n_cells, n_genes), dtype=np.float32)
        # Clusters without a perturbed trajectory keep a zero delta.
        for cluster, cell_indices in cluster_cells.items():
            v_wt = compute_velocity(adata, X=X_wt_final[cell_indices], cluster=cluster)
            v_pert = compute_velocity(adata, X=X_pert_final[cell_indices], cluster=cluster)
            delta_velocity[cell_indices] = v_pert - v_wt
        if verbose:
            print("Projecting to embedding...")
        embedding_flow = project_to_embedding(
            adata, delta_velocity, basis=basis,
            n_neighbors=n_neighbors, n_jobs=n_jobs
        )
    else:  # difference method
        if verbose:
            print("Projecting expression difference to embedding...")
        embedding_flow = project_to_embedding(
            adata, delta_X, basis=basis,
            n_neighbors=n_neighbors, n_jobs=n_jobs
        )

    # Store results
    adata.obsm[f'ode_perturbation_flow_{basis}'] = embedding_flow
    adata.uns['ode_perturbation_flow_params'] = {
        'basis': basis,
        'method': method,
        'time_point': time_point,
        'clusters': list(wt_trajectories.keys())
    }
    if verbose:
        print(f"ODE perturbation flow stored in adata.obsm['ode_perturbation_flow_{basis}']")
    return embedding_flow