"""ODE solver for gene regulatory network dynamics."""
import numpy as np
from scipy.integrate import odeint, solve_ivp
from typing import Optional
from anndata import AnnData
from .._utils.math import sigmoid
from .._utils.io import get_genes_used
[docs]
class ODESolver:
"""
ODE solver for Hopfield network dynamics.
Solves: dx/dt = W * sigmoid(x) - gamma * x + I
With constraints to ensure expression values remain non-negative
and bounded to prevent divergence. Supports fixing certain genes
at constant values (e.g., for knockout/overexpression simulations).
"""
[docs]
def __init__(
self,
W: np.ndarray,
bias_vector: np.ndarray,
gamma: np.ndarray,
threshold: np.ndarray,
exponent: np.ndarray,
x_min: float = 0.0,
x_max: Optional[np.ndarray] = None,
fixed_indices: Optional[np.ndarray] = None,
fixed_values: Optional[np.ndarray] = None
):
"""
Initialize ODE solver.
Parameters
----------
W : np.ndarray
Interaction matrix
I : np.ndarray
External input
gamma : np.ndarray
Degradation rates
threshold : np.ndarray
Sigmoid threshold parameters
exponent : np.ndarray
Sigmoid exponent parameters
x_min : float, optional (default: 0.0)
Minimum expression value (non-negative constraint)
x_max : np.ndarray, optional
Maximum expression values per gene. If None, no upper bound.
fixed_indices : np.ndarray, optional
Indices of genes to keep fixed (e.g., perturbed genes)
fixed_values : np.ndarray, optional
Values to fix the genes at (must match length of fixed_indices)
"""
self.W = W
self.I = bias_vector
self.gamma = gamma
self.threshold = threshold
self.exponent = exponent
self.x_min = x_min
self.x_max = x_max
self.fixed_indices = fixed_indices
self.fixed_values = fixed_values
def set_fixed_genes(
self,
fixed_indices: Optional[np.ndarray],
fixed_values: Optional[np.ndarray]
) -> None:
"""
Set genes to be held fixed during simulation.
Parameters
----------
fixed_indices : np.ndarray
Indices of genes to keep fixed
fixed_values : np.ndarray
Values to fix the genes at
"""
self.fixed_indices = fixed_indices
self.fixed_values = fixed_values
def _clip(self, x: np.ndarray) -> np.ndarray:
x = np.maximum(x, self.x_min)
if self.x_max is not None:
x = np.minimum(x, self.x_max)
return x
def _clip_trajectory(self, traj: np.ndarray) -> np.ndarray:
traj = np.maximum(traj, self.x_min)
if self.x_max is not None:
traj = np.minimum(traj, self.x_max)
return traj
def _enforce_fixed(self, x: np.ndarray) -> None:
"""Overwrite fixed-gene positions (in-place, 1-D)."""
if self.fixed_indices is not None and len(self.fixed_indices) > 0:
x[self.fixed_indices] = self.fixed_values
def _enforce_fixed_trajectory(self, traj: np.ndarray) -> None:
"""Overwrite fixed-gene columns (in-place, 2-D n_times×n_genes)."""
if self.fixed_indices is not None and len(self.fixed_indices) > 0:
traj[:, self.fixed_indices] = self.fixed_values
def dynamics(self, x: np.ndarray, t: float) -> np.ndarray:
"""Compute dx/dt with soft boundary enforcement."""
# Clip x to valid range before computing dynamics
x_clipped = self._clip(x.copy())
sig = sigmoid(x_clipped, self.threshold, self.exponent)
dxdt = self.W @ sig - self.gamma * x_clipped + self.I
# Soft boundary: if x is at lower bound, don't let it go more negative
at_lower = x <= self.x_min
dxdt[at_lower] = np.maximum(dxdt[at_lower], 0)
# If x is at upper bound, don't let it go more positive
if self.x_max is not None:
at_upper = x >= self.x_max
dxdt[at_upper] = np.minimum(dxdt[at_upper], 0)
# Fixed genes have zero derivative (they don't change)
if self.fixed_indices is not None and len(self.fixed_indices) > 0:
dxdt[self.fixed_indices] = 0.0
return dxdt
def dynamics_ivp(self, t: float, x: np.ndarray) -> np.ndarray:
"""Compute dx/dt for solve_ivp (arguments reversed)."""
return self.dynamics(x, t)
def dynamics_batch(self, X: np.ndarray, t: float) -> np.ndarray:
"""Compute dx/dt for a batch of states (n_cells, n_genes).
Vectorized equivalent of dynamics(). Uses sig @ W.T instead of W @ sig
to handle the 2-D case correctly.
"""
X_clipped = np.maximum(X, self.x_min)
if self.x_max is not None:
X_clipped = np.minimum(X_clipped, self.x_max)
sig = sigmoid(X_clipped, self.threshold, self.exponent) # (n_cells, n_genes)
dxdt = sig @ self.W.T - self.gamma * X_clipped + self.I # (n_cells, n_genes)
at_lower = X <= self.x_min
dxdt[at_lower] = np.maximum(dxdt[at_lower], 0)
if self.x_max is not None:
at_upper = X >= self.x_max
dxdt[at_upper] = np.minimum(dxdt[at_upper], 0)
if self.fixed_indices is not None and len(self.fixed_indices) > 0:
dxdt[:, self.fixed_indices] = 0.0
return dxdt
def solve(
self,
x0: np.ndarray,
t_span: np.ndarray,
method: str = 'euler',
clip_each_step: bool = True
) -> np.ndarray:
"""
Solve ODE from initial condition x0.
Parameters
----------
x0 : np.ndarray
Initial condition (must be non-negative)
t_span : np.ndarray
Time points
method : str, optional (default: 'euler')
Integration method:
- 'euler': Simple Euler method with clipping (stable, recommended)
- 'odeint': scipy.integrate.odeint (may diverge)
- 'RK45': scipy.integrate.solve_ivp with RK45
clip_each_step : bool, optional (default: True)
Whether to clip values at each step (prevents divergence)
Returns
-------
np.ndarray
Solution trajectory (len(t_span) × n_genes)
"""
# Ensure initial condition is valid
x0 = self._clip(x0)
self._enforce_fixed(x0)
if method == 'euler':
return self._solve_euler(x0, t_span, clip_each_step)
elif method == 'odeint':
trajectory = odeint(self.dynamics, x0, t_span)
if clip_each_step:
trajectory = self._clip_trajectory(trajectory)
self._enforce_fixed_trajectory(trajectory)
return trajectory
elif method in ['RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA']:
return self._solve_ivp(x0, t_span, method, clip_each_step)
else:
raise ValueError(f"Unknown method: {method}. Use 'euler', 'odeint', or scipy method names.")
def _solve_euler(
self,
x0: np.ndarray,
t_span: np.ndarray,
clip_each_step: bool = True
) -> np.ndarray:
"""
Solve ODE using Euler method with clipping at each step.
This is more stable for stiff systems and ensures non-negativity.
Fixed genes (if any) are held constant throughout the simulation.
"""
n_steps = len(t_span)
n_genes = len(x0)
trajectory = np.zeros((n_steps, n_genes), dtype=np.float32)
trajectory[0] = x0.copy()
x = x0.copy()
for i in range(1, n_steps):
dt = t_span[i] - t_span[i-1]
# Compute derivative
dxdt = self.dynamics(x, t_span[i-1])
# Euler step
x = x + dt * dxdt
# Clip to valid range and enforce fixed genes
if clip_each_step:
x = self._clip(x)
self._enforce_fixed(x)
trajectory[i] = x
return trajectory
def _solve_ivp(
self,
x0: np.ndarray,
t_span: np.ndarray,
method: str,
clip_each_step: bool
) -> np.ndarray:
"""Solve using scipy solve_ivp."""
sol = solve_ivp(
self.dynamics_ivp,
(t_span[0], t_span[-1]),
x0,
method=method,
t_eval=t_span,
dense_output=False
)
trajectory = sol.y.T # Transpose to (n_times, n_genes)
if clip_each_step:
trajectory = self._clip_trajectory(trajectory)
self._enforce_fixed_trajectory(trajectory)
return trajectory
[docs]
def create_solver(
adata: AnnData,
cluster: str,
degradation_key: str = 'gamma',
spliced_key: Optional[str] = None,
x_max_percentile: float = 99.0
) -> ODESolver:
"""
Create ODE solver for a specific cluster.
Parameters
----------
adata : AnnData
Annotated data object with fitted interactions
cluster : str
Cluster name
degradation_key : str, optional
Key for degradation rates
spliced_key : str, optional
Key for expression data to compute bounds. If None, uses scHopfield default.
x_max_percentile : float, optional (default: 99.0)
Percentile of expression values to use as upper bound.
Set to None to disable upper bound.
Returns
-------
ODESolver
Configured ODE solver with bounds
"""
from .._utils.io import get_matrix, to_numpy
from ._utils import _get_W_matrix, _compute_x_bounds
genes = get_genes_used(adata)
W = _get_W_matrix(adata, cluster, use_cluster_specific=True)
bias_vector = adata.var[f'I_{cluster}'].values[genes]
gamma_key = f'gamma_{cluster}'
gamma = adata.var[gamma_key].values[genes] if gamma_key in adata.var else adata.var[degradation_key].values[genes]
threshold = adata.var['sigmoid_threshold'].values[genes]
exponent = adata.var['sigmoid_exponent'].values[genes]
# Compute upper bounds from data
if x_max_percentile is not None:
if spliced_key is None:
spliced_key = adata.uns.get('scHopfield', {}).get('spliced_key', 'Ms')
X = to_numpy(get_matrix(adata, spliced_key, genes=genes))
_, x_max = _compute_x_bounds(X, x_max_percentile, multiplier=2.0)
else:
x_max = None
return ODESolver(W, bias_vector, gamma, threshold, exponent, x_min=0.0, x_max=x_max)